From ecb67bbcef96634343915059d3757054dc75b603 Mon Sep 17 00:00:00 2001 From: Abeeujah Date: Sat, 6 Jun 2026 11:44:18 +0100 Subject: [PATCH 01/13] Rename OrdF64 to PositiveF64 Update the type name to better reflect its domain invariants. The underlying floats are guaranteed to be both positive and not-NaN, making `PositiveF64` a more accurate and descriptive name than `OrdF64`, which only highlighted the `Ord` trait implementation. --- src/policy/compiler.rs | 32 ++++++++++++++++-------------- src/policy/concrete.rs | 44 +++++++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 4ffadb804..99b0a94bf 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -18,20 +18,22 @@ use crate::policy::Concrete; use crate::prelude::*; use crate::{policy, Miniscript, MiniscriptKey, Terminal}; -type PolicyCache = - BTreeMap<(Concrete, OrdF64, Option), BTreeMap>>; +type PolicyCache = BTreeMap< + (Concrete, PositiveF64, Option), + BTreeMap>, +>; /// Ordered f64 for comparison. #[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) struct OrdF64(pub f64); +pub(crate) struct PositiveF64(pub f64); -impl Eq for OrdF64 {} +impl Eq for PositiveF64 {} // We could derive PartialOrd, but we can't derive Ord, and clippy wants us // to derive both or neither. Better to be explicit. -impl PartialOrd for OrdF64 { +impl PartialOrd for PositiveF64 { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for OrdF64 { +impl Ord for PositiveF64 { fn cmp(&self, other: &Self) -> cmp::Ordering { // will panic if given NaN self.0.partial_cmp(&other.0).unwrap() @@ -129,7 +131,7 @@ impl From for CompilerError { } /// Hash required for using OrdF64 as key for hashmap -impl hash::Hash for OrdF64 { +impl hash::Hash for PositiveF64 { fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } } @@ -150,7 +152,7 @@ struct CompilationKey { /// The probability of dissatisfaction of the compilation of the policy. Note /// that all possible compilations of a (sub)policy have the same sat-prob /// and only differ in dissat_prob. - dissat_prob: Option, + dissat_prob: Option, } impl CompilationKey { @@ -164,7 +166,7 @@ impl CompilationKey { /// Helper to create compilation key from components fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { - Self { ty, expensive_verify, dissat_prob: dissat_prob.map(OrdF64) } + Self { ty, expensive_verify, dissat_prob: dissat_prob.map(PositiveF64) } } } @@ -787,8 +789,8 @@ where Ctx: ScriptContext, { //Check the cache for hits - let ord_sat_prob = OrdF64(sat_prob); - let ord_dissat_prob = dissat_prob.map(OrdF64); + let ord_sat_prob = PositiveF64(sat_prob); + let ord_dissat_prob = dissat_prob.map(PositiveF64); if let Some(ret) = policy_cache.get(&(policy.clone(), ord_sat_prob, ord_dissat_prob)) { return Ok(ret.clone()); } @@ -1179,10 +1181,10 @@ where best_compilations(policy_cache, policy, sat_prob, dissat_prob)? .into_iter() .filter(|&(key, _)| { - key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob.map(OrdF64) + key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob.map(PositiveF64) }) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } @@ -1204,10 +1206,10 @@ where key.ty.corr.base == basic_type && key.ty.corr.unit && val.ms.ty.mall.dissat == types::Dissat::Unique - && key.dissat_prob == dissat_prob.map(OrdF64) + && key.dissat_prob == dissat_prob.map(PositiveF64) }) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 3e2220df5..e9a2e1beb 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -12,7 +12,7 @@ use bitcoin::absolute; use { crate::descriptor::TapTree, crate::miniscript::ScriptContext, - crate::policy::compiler::{self, CompilerError, OrdF64}, + crate::policy::compiler::{self, CompilerError, PositiveF64}, crate::Descriptor, crate::Miniscript, crate::Tap, @@ -246,7 +246,7 @@ impl Policy { let internal_key = self .tapleaf_probability_iter() .filter_map(|(prob, ref pol)| match pol { - Self::Key(pk) => Some((OrdF64(prob), pk)), + Self::Key(pk) => Some((PositiveF64(prob), pk)), _ => None, }) .max_by_key(|(prob, _)| *prob) @@ -291,14 +291,15 @@ impl Policy { match policy { Self::Trivial => None, policy => { - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = + vec![]; for (prob, pol) in policy.tapleaf_probability_iter() { // policy corresponding to the key (replaced by unsatisfiable) is skipped if *pol == Self::Unsatisfiable { continue; } let compilation = compiler::best_compilation::(pol)?; - leaf_compilations.push((OrdF64(prob), compilation)); + leaf_compilations.push((PositiveF64(prob), compilation)); } if !leaf_compilations.is_empty() { let tap_tree = with_huffman_tree::(leaf_compilations); @@ -350,7 +351,7 @@ impl Policy { if n > max_leaves { return Err(CompilerError::TooManyTapleaves { n, max: max_leaves }); } - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = vec![]; for (leaf_idx, (prob, pol)) in leaves.iter().enumerate() { if **pol == Self::Unsatisfiable { continue; @@ -361,7 +362,7 @@ impl Policy { leaf_index: leaf_idx, }); } - leaf_compilations.push((OrdF64(*prob), compilation)); + leaf_compilations.push((PositiveF64(*prob), compilation)); } if !leaf_compilations.is_empty() { Some(with_huffman_tree::(leaf_compilations)) @@ -418,7 +419,7 @@ impl Policy { .filter(|x| x.1 != Arc::new(Self::Unsatisfiable)) .map(|(prob, pol)| { ( - OrdF64(prob), + PositiveF64(prob), compiler::best_compilation(pol.as_ref()).unwrap(), ) }) @@ -588,16 +589,16 @@ impl Policy { max_leaves: usize, expand_fn: fn(&Self, f64) -> Vec<(f64, Arc)>, ) -> Vec<(f64, Arc)> { - let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); + let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); // Store probability corresponding to policy in the enumerated tree. This is required since // owing to the current [policy element enumeration algorithm][`Policy::enumerate_pol`], // two passes of the algorithm might result in same sub-policy showing up. Currently, we // merge the nodes by adding up the corresponding probabilities for the same policy. - let mut pol_prob_map = BTreeMap::, OrdF64>::new(); + let mut pol_prob_map = BTreeMap::, PositiveF64>::new(); let arc_self = Arc::new(self); - tapleaf_prob_vec.insert((Reverse(OrdF64(prob)), Arc::clone(&arc_self))); - pol_prob_map.insert(Arc::clone(&arc_self), OrdF64(prob)); + tapleaf_prob_vec.insert((Reverse(PositiveF64(prob)), Arc::clone(&arc_self))); + pol_prob_map.insert(Arc::clone(&arc_self), PositiveF64(prob)); // Since we know that policy enumeration *must* result in increase in total number of nodes, // we can maintain the length of the ordered set to check if the @@ -612,7 +613,7 @@ impl Policy { // Stopping condition: When NONE of the inputs can be further enumerated. 'outer: loop { //--- FIND a plausible node --- - let mut prob: Reverse = Reverse(OrdF64(0.0)); + let mut prob: Reverse = Reverse(PositiveF64(0.0)); let mut curr_policy: Arc = Arc::new(Self::Unsatisfiable); let mut curr_pol_replace_vec: Vec<(f64, Arc)> = vec![]; let mut no_more_enum = false; @@ -659,7 +660,7 @@ impl Policy { // OPTIMIZATION - Move marked nodes into final vector for (p, pol) in to_del { - assert!(tapleaf_prob_vec.remove(&(Reverse(OrdF64(p)), pol.clone()))); + assert!(tapleaf_prob_vec.remove(&(Reverse(PositiveF64(p)), pol.clone()))); pol_prob_map.remove(&pol); ret.push((p, pol.clone())); } @@ -669,12 +670,13 @@ impl Policy { match pol_prob_map.get(&policy) { Some(prev_prob) => { assert!(tapleaf_prob_vec.remove(&(Reverse(*prev_prob), policy.clone()))); - tapleaf_prob_vec.insert((Reverse(OrdF64(prev_prob.0 + p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(prev_prob.0 + p)); + tapleaf_prob_vec + .insert((Reverse(PositiveF64(prev_prob.0 + p)), policy.clone())); + pol_prob_map.insert(policy.clone(), PositiveF64(prev_prob.0 + p)); } None => { - tapleaf_prob_vec.insert((Reverse(OrdF64(p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(p)); + tapleaf_prob_vec.insert((Reverse(PositiveF64(p)), policy.clone())); + pol_prob_map.insert(policy.clone(), PositiveF64(p)); } } } @@ -1137,8 +1139,10 @@ fn has_if_fragment(ms: &Miniscript) -> bool { /// Creates a Huffman Tree from compiled [`Miniscript`] nodes. #[cfg(feature = "compiler")] -fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) -> TapTree { - let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); +fn with_huffman_tree( + ms: Vec<(PositiveF64, Miniscript)>, +) -> TapTree { + let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); for (prob, script) in ms { node_weights.push((Reverse(prob), TapTree::leaf(script))); } @@ -1149,7 +1153,7 @@ fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) let p = (p1.0).0 + (p2.0).0; node_weights.push(( - Reverse(OrdF64(p)), + Reverse(PositiveF64(p)), TapTree::combine(s1, s2) .expect("huffman tree cannot produce depth > 128 given sane weights"), )); From 56d820a5268ee964588d052fa3e5e5b1ce1ac916 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 13 Jun 2026 19:57:52 +0000 Subject: [PATCH 02/13] move PositiveF64 to its own module Code move only. Retain the public constructor. We will eliminate its use over the following commits. Also, publicly expose the type in src/lib.rs, since we are going to start enforcing invariants on it and making it generally useful. --- src/lib.rs | 1 + src/policy/compiler.rs | 27 ++------------------------- src/policy/concrete.rs | 4 +++- src/primitives/mod.rs | 1 + src/primitives/positive_f64.rs | 29 +++++++++++++++++++++++++++++ 5 files changed, 36 insertions(+), 26 deletions(-) create mode 100644 src/primitives/positive_f64.rs diff --git a/src/lib.rs b/src/lib.rs index 2d9beaf78..ad6e5f7bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,7 @@ pub use crate::miniscript::satisfy::{Preimage32, Satisfier}; pub use crate::miniscript::{hash256, Miniscript}; use crate::prelude::*; pub use crate::primitives::absolute_locktime::{AbsLockTime, AbsLockTimeError}; +pub use crate::primitives::positive_f64::PositiveF64; pub use crate::primitives::relative_locktime::{RelLockTime, RelLockTimeError}; pub use crate::primitives::threshold::{Threshold, ThresholdError}; pub use crate::validation::{Error as ValidationError, ValidationParams}; diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 99b0a94bf..ca61b9ec1 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,7 +5,7 @@ //! Optimizing compiler from concrete policies to Miniscript //! -use core::{cmp, f64, fmt, hash, mem}; +use core::{f64, fmt, mem}; #[cfg(feature = "std")] use std::error; @@ -16,30 +16,12 @@ use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; use crate::prelude::*; -use crate::{policy, Miniscript, MiniscriptKey, Terminal}; +use crate::{policy, Miniscript, MiniscriptKey, PositiveF64, Terminal}; type PolicyCache = BTreeMap< (Concrete, PositiveF64, Option), BTreeMap>, >; - -/// Ordered f64 for comparison. -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) struct PositiveF64(pub f64); - -impl Eq for PositiveF64 {} -// We could derive PartialOrd, but we can't derive Ord, and clippy wants us -// to derive both or neither. Better to be explicit. -impl PartialOrd for PositiveF64 { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } -} -impl Ord for PositiveF64 { - fn cmp(&self, other: &Self) -> cmp::Ordering { - // will panic if given NaN - self.0.partial_cmp(&other.0).unwrap() - } -} - /// Detailed error type for compiler. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub enum CompilerError { @@ -130,11 +112,6 @@ impl From for CompilerError { fn from(e: policy::concrete::PolicyError) -> Self { Self::PolicyError(e) } } -/// Hash required for using OrdF64 as key for hashmap -impl hash::Hash for PositiveF64 { - fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } -} - /// Compilation key: This represents the state of the best possible compilation /// of a given policy(implicitly keyed). #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord, Hash)] diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index e9a2e1beb..8dedca6f2 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -12,7 +12,7 @@ use bitcoin::absolute; use { crate::descriptor::TapTree, crate::miniscript::ScriptContext, - crate::policy::compiler::{self, CompilerError, PositiveF64}, + crate::policy::compiler::{self, CompilerError}, crate::Descriptor, crate::Miniscript, crate::Tap, @@ -27,6 +27,8 @@ use crate::prelude::*; use crate::sync::Arc; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; +#[cfg(feature = "compiler")] +use crate::PositiveF64; use crate::{ AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, RelLockTime, Threshold, Translator, }; diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index 919e6c8ba..d7afefdcb 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -12,5 +12,6 @@ //! should be re-exported at the crate root. pub mod absolute_locktime; +pub mod positive_f64; pub mod relative_locktime; pub mod threshold; diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs new file mode 100644 index 000000000..5973ecaa1 --- /dev/null +++ b/src/primitives/positive_f64.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Positive floats ("branch probabilities" for policies) + +use core::{cmp, hash}; + +/// Ordered f64 for comparison. +#[derive(Copy, Clone, PartialEq, Debug)] +pub struct PositiveF64(pub f64); + +impl Eq for PositiveF64 {} + +// We could derive PartialOrd, but we can't derive Ord, and clippy wants us +// to derive both or neither. Better to be explicit. +impl PartialOrd for PositiveF64 { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for PositiveF64 { + fn cmp(&self, other: &Self) -> cmp::Ordering { + // will panic if given NaN + self.0.partial_cmp(&other.0).unwrap() + } +} + +/// Hash required for using OrdF64 as key for hashmap +impl hash::Hash for PositiveF64 { + fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } +} From 008864bab0805285c28853411061ceb8ace9b818 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 13 Jun 2026 20:30:35 +0000 Subject: [PATCH 03/13] policy: use NonZeroU32 rather than usize for 'Or' counts Our parser only allows nonzero u32s, so just use the type. This adds a bit of noise to the current code, but when we move the probability computations from f64 to PositiveF64, it will pay dividends. Also eliminates a couple panic paths where we were using .sum or bare addition to add up probabilities; now we always convert to f64 before doing the addition, which will lose precision but not panic if the numbers get too big. --- src/expression/mod.rs | 7 ++++--- src/policy/compiler.rs | 24 ++++++++++++++---------- src/policy/concrete.rs | 29 ++++++++++++++++++----------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 6bfdfd055..21a479526 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -27,6 +27,7 @@ mod error; +use core::num::NonZeroU32; use core::ops; use core::str::FromStr; @@ -679,7 +680,7 @@ impl<'a> Tree<'a> { } /// Parse a string as a u32, forbidding zero. -pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result { +pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result { if s == "0" { return Err(ParseNumError::IllegalZero { context }); } @@ -688,7 +689,7 @@ pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result Result { // Special-case 0 since it is the only number which may start with a leading zero. return Ok(0); } - parse_num_nonzero(s, "") + parse_num_nonzero(s, "").map(u32::from) } #[cfg(test)] diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index ca61b9ec1..64dd3be88 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -841,9 +841,9 @@ where compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [1.0, 0.0]); } Concrete::Or(ref subs) => { - let total = (subs[0].0 + subs[1].0) as f64; - let lw = subs[0].0 as f64 / total; - let rw = subs[1].0 as f64 / total; + let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; + let lw = u32::from(subs[0].0) as f64 / total; + let rw = u32::from(subs[1].0) as f64 / total; //and-or if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { @@ -1192,6 +1192,7 @@ where #[cfg(test)] mod tests { + use core::num::NonZeroU32; use core::str::FromStr; use bitcoin::blockdata::{opcodes, script}; @@ -1207,6 +1208,9 @@ mod tests { type TapAstElemExt = policy::compiler::AstElemExt; type SegwitMiniScript = Miniscript; + #[allow(unsafe_code)] + const ONE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1) }; // can be NonZeroU32::MIN in 1.70 + fn pubkeys_and_a_sig(n: usize) -> (Vec, secp256k1::ecdsa::Signature) { let mut ret = Vec::with_capacity(n); let secp = secp256k1::Secp256k1::new(); @@ -1344,14 +1348,14 @@ mod tests { // Liquid policy let policy: BPolicy = Concrete::Or(vec![ ( - 127, + NonZeroU32::new(127).unwrap(), Arc::new(Concrete::Thresh( Threshold::from_iter(3, key_pol[0..5].iter().map(|p| (p.clone()).into())) .unwrap(), )), ), ( - 1, + NonZeroU32::new(1).unwrap(), Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(RelLockTime::from_height(10000).unwrap())), Arc::new(Concrete::Thresh( @@ -1520,8 +1524,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), + (ONE, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), + (ONE, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), ]) .compile(); let script_size = thresh_res.clone().map(|m| m.script_size()); @@ -1591,14 +1595,14 @@ mod tests { // Test that we refuse to compile policies with duplicated keys let (keys, _) = pubkeys_and_a_sig(1); let key = Arc::new(Concrete::Key(keys[0])); - let res = - Concrete::Or(vec![(1, Arc::clone(&key)), (1, Arc::clone(&key))]).compile::(); + let res = Concrete::Or(vec![(ONE, Arc::clone(&key)), (ONE, Arc::clone(&key))]) + .compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) ); // Same for legacy - let res = Concrete::Or(vec![(1, key.clone()), (1, key)]).compile::(); + let res = Concrete::Or(vec![(ONE, key.clone()), (ONE, key)]).compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 8dedca6f2..47cda1f68 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -3,6 +3,7 @@ //! Concrete Policies //! +use core::num::NonZeroU32; use core::{cmp, fmt, str}; #[cfg(feature = "std")] use std::error; @@ -67,7 +68,7 @@ pub enum Policy { And(Vec>), /// A list of sub-policies, one of which must be satisfied, along with /// relative probabilities for each one. - Or(Vec<(usize, Arc)>), + Or(Vec<(NonZeroU32, Arc)>), /// A set of descriptors, satisfactions must be provided for `k` of them. Thresh(Threshold, 0>), } @@ -165,6 +166,12 @@ impl error::Error for PolicyError { } } +/// Sums a series of `NonZeroU32`s by first converting them to floats. +#[cfg(feature = "compiler")] +pub(super) fn sum_nonzero_usizes(iter: impl Iterator) -> f64 { + iter.map(|n| u32::from(n) as f64).sum::() +} + #[cfg(feature = "compiler")] struct TapleafProbabilityIter<'p, Pk: MiniscriptKey> { stack: Vec<(f64, &'p Policy)>, @@ -180,9 +187,9 @@ impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { match top { Policy::Or(ref subs) => { - let total_sub_prob = subs.iter().map(|prob_sub| prob_sub.0).sum::(); + let total_sub_prob = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); for (sub_prob, sub) in subs.iter().rev() { - let ratio = *sub_prob as f64 / total_sub_prob as f64; + let ratio = u32::from(*sub_prob) as f64 / total_sub_prob; self.stack.push((top_prob * ratio, sub)); } } @@ -504,9 +511,9 @@ impl Policy { fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); + let total_odds = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + .map(|(odds, pol)| (prob * u32::from(*odds) as f64 / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { @@ -529,9 +536,9 @@ impl Policy { fn enumerate_pol_native(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); + let total_odds = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + .map(|(odds, pol)| (prob * u32::from(*odds) as f64 / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { @@ -1017,7 +1024,7 @@ impl expression::FromTree for Policy { .map_err(From::from) .map_err(Error::Parse)?; - let mut stack = Vec::<(usize, _)>::with_capacity(128); + let mut stack = Vec::<(NonZeroU32, _)>::with_capacity(128); for node in root.pre_order_iter().rev() { let allow_prob; // Before doing anything else, check if this is the inner value of a terminal. @@ -1050,10 +1057,10 @@ impl expression::FromTree for Policy { }; let frag_prob = match frag_prob { - None => 1, + None => NonZeroU32::new(1).unwrap(), // NonZeroU32::MIN available in Rust 1.70 Some(s) => expression::parse_num_nonzero(s, "fragment probability") .map_err(From::from) - .map_err(Error::Parse)? as usize, + .map_err(Error::Parse)?, }; let new = @@ -1208,7 +1215,7 @@ pub enum TreeChildren<'a, Pk: MiniscriptKey> { /// A conjunction or threshold node's children. And(&'a [Arc>]), /// A disjunction node's children. - Or(&'a [(usize, Arc>)]), + Or(&'a [(NonZeroU32, Arc>)]), } impl<'a, Pk: MiniscriptKey> TreeLike for &'a Policy { From 4c7344c590eaf58fc1fac52d9075673eb6d4e36c Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 13 Jun 2026 22:18:21 +0000 Subject: [PATCH 04/13] positive_f64: add normalized_iter to get an iterator over normalized values Once we remove the unchecked constructor(s) from PositiveF64, we won't be able to do these "sum everything then divide by sum" constructions. Instead, encapsulate them into PositiveF64::normalized_iter. This is mildly annoying to use, but I wasn't able to come up with a nicer solution that avoided gratuitous clones, gave sane compiler errors if you misuse them, and had readable code. --- src/policy/concrete.rs | 28 ++++++-------- src/primitives/positive_f64.rs | 67 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 47cda1f68..986b11c84 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -166,12 +166,6 @@ impl error::Error for PolicyError { } } -/// Sums a series of `NonZeroU32`s by first converting them to floats. -#[cfg(feature = "compiler")] -pub(super) fn sum_nonzero_usizes(iter: impl Iterator) -> f64 { - iter.map(|n| u32::from(n) as f64).sum::() -} - #[cfg(feature = "compiler")] struct TapleafProbabilityIter<'p, Pk: MiniscriptKey> { stack: Vec<(f64, &'p Policy)>, @@ -187,10 +181,10 @@ impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { match top { Policy::Or(ref subs) => { - let total_sub_prob = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); - for (sub_prob, sub) in subs.iter().rev() { - let ratio = u32::from(*sub_prob) as f64 / total_sub_prob; - self.stack.push((top_prob * ratio, sub)); + let normalized_iter = + PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + for (ratio, (_, sub)) in normalized_iter.zip(subs.iter()).rev() { + self.stack.push((top_prob * f64::from(ratio), sub)); } } Policy::Thresh(ref thresh) if thresh.is_or() => { @@ -511,9 +505,10 @@ impl Policy { fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); - subs.iter() - .map(|(odds, pol)| (prob * u32::from(*odds) as f64 / total_odds, pol.clone())) + let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + normalized_iter + .zip(subs.iter()) + .map(|(odds, (_, pol))| (prob * f64::from(odds), pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { @@ -536,9 +531,10 @@ impl Policy { fn enumerate_pol_native(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = sum_nonzero_usizes(subs.iter().map(|prob_sub| prob_sub.0)); - subs.iter() - .map(|(odds, pol)| (prob * u32::from(*odds) as f64 / total_odds, pol.clone())) + let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + normalized_iter + .zip(subs.iter()) + .map(|(odds, (_, pol))| (prob * f64::from(odds), pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs index 5973ecaa1..f72c91606 100644 --- a/src/primitives/positive_f64.rs +++ b/src/primitives/positive_f64.rs @@ -2,12 +2,34 @@ //! Positive floats ("branch probabilities" for policies) +use core::iter::FusedIterator; +use core::num::NonZeroU32; use core::{cmp, hash}; /// Ordered f64 for comparison. #[derive(Copy, Clone, PartialEq, Debug)] pub struct PositiveF64(pub f64); +impl PositiveF64 { + /// Takes an iterator over [`PositiveF64`] and produces a new iterator where + /// each item is divided so that they all total to 1. + /// + /// On an empty iterator, returns a new empty iterator. + /// + /// Internally clones the iterator and runs it twice, so best to only use + /// this with reference-based iterators obtained with e.g. `slice.iter()` + /// rather than "owning" iterators like you'd get from `vec.into_iter()`. + pub fn normalized_iter(iter: I) -> NormalizedIterator + where + I: Iterator + Clone, + { + // Compute the sum of all the items in the iterator. Because all items in + // the iterator are positive, this will be 0 iff the iterator is empty. + let sum = iter.clone().map(|x| x.0).sum::(); + NormalizedIterator { iter, sum } + } +} + impl Eq for PositiveF64 {} // We could derive PartialOrd, but we can't derive Ord, and clippy wants us @@ -27,3 +49,48 @@ impl Ord for PositiveF64 { impl hash::Hash for PositiveF64 { fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } } + +impl From for f64 { + fn from(value: PositiveF64) -> Self { value.0 } +} + +impl From for PositiveF64 { + fn from(value: NonZeroU32) -> Self { Self(f64::from(u32::from(value))) } +} + +pub struct NormalizedIterator { + iter: I, + /// Sum must be nonnegative, and may only be zero if `iter` is empty. + sum: f64, +} + +impl Iterator for NormalizedIterator +where + I: Iterator, +{ + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next().map(|x| PositiveF64(x.0 / self.sum)) + } + + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } +} + +impl DoubleEndedIterator for NormalizedIterator +where + I: Iterator + DoubleEndedIterator, +{ + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|x| PositiveF64(x.0 / self.sum)) + } +} + +impl ExactSizeIterator for NormalizedIterator where + I: Iterator + ExactSizeIterator +{ +} + +impl FusedIterator for NormalizedIterator where + I: Iterator + FusedIterator +{ +} From 8fe2d4a43228127cbccc575f54e844a63aaf4612 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 13 Jun 2026 20:00:37 +0000 Subject: [PATCH 05/13] concrete: use PositiveF64 directly in TapleafProbabilityIter Look ma, no type conversions! --- src/policy/concrete.rs | 14 ++++++------ src/primitives/positive_f64.rs | 41 ++++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 986b11c84..494df9602 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -168,12 +168,12 @@ impl error::Error for PolicyError { #[cfg(feature = "compiler")] struct TapleafProbabilityIter<'p, Pk: MiniscriptKey> { - stack: Vec<(f64, &'p Policy)>, + stack: Vec<(PositiveF64, &'p Policy)>, } #[cfg(feature = "compiler")] impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { - type Item = (f64, &'p Policy); + type Item = (PositiveF64, &'p Policy); fn next(&mut self) -> Option { loop { @@ -184,11 +184,11 @@ impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); for (ratio, (_, sub)) in normalized_iter.zip(subs.iter()).rev() { - self.stack.push((top_prob * f64::from(ratio), sub)); + self.stack.push((top_prob * ratio, sub)); } } Policy::Thresh(ref thresh) if thresh.is_or() => { - let n64 = thresh.n() as f64; + let n64 = PositiveF64::n(thresh); for sub in thresh.iter().rev() { self.stack.push((top_prob / n64, sub)); } @@ -240,7 +240,7 @@ impl Policy { /// leaf-nodes to [`MAX_COMPILATION_LEAVES`]. #[cfg(feature = "compiler")] fn tapleaf_probability_iter(&self) -> TapleafProbabilityIter<'_, Pk> { - TapleafProbabilityIter { stack: vec![(1.0, self)] } + TapleafProbabilityIter { stack: vec![(PositiveF64::ONE, self)] } } /// Extracts the internal_key from this policy tree. @@ -249,7 +249,7 @@ impl Policy { let internal_key = self .tapleaf_probability_iter() .filter_map(|(prob, ref pol)| match pol { - Self::Key(pk) => Some((PositiveF64(prob), pk)), + Self::Key(pk) => Some((prob, pk)), _ => None, }) .max_by_key(|(prob, _)| *prob) @@ -302,7 +302,7 @@ impl Policy { continue; } let compilation = compiler::best_compilation::(pol)?; - leaf_compilations.push((PositiveF64(prob), compilation)); + leaf_compilations.push((prob, compilation)); } if !leaf_compilations.is_empty() { let tap_tree = with_huffman_tree::(leaf_compilations); diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs index f72c91606..b2bea9380 100644 --- a/src/primitives/positive_f64.rs +++ b/src/primitives/positive_f64.rs @@ -1,16 +1,20 @@ // SPDX-License-Identifier: CC0-1.0 //! Positive floats ("branch probabilities" for policies) - use core::iter::FusedIterator; use core::num::NonZeroU32; -use core::{cmp, hash}; +use core::{cmp, hash, ops}; + +use crate::Threshold; /// Ordered f64 for comparison. #[derive(Copy, Clone, PartialEq, Debug)] pub struct PositiveF64(pub f64); impl PositiveF64 { + /// The constant one. + pub const ONE: Self = Self(1.0); + /// Takes an iterator over [`PositiveF64`] and produces a new iterator where /// each item is divided so that they all total to 1. /// @@ -28,6 +32,11 @@ impl PositiveF64 { let sum = iter.clone().map(|x| x.0).sum::(); NormalizedIterator { iter, sum } } + + /// The 'n' value of a threshold, as a [`PositiveF64`] + pub fn n(t: &Threshold) -> Self { + Self(t.n() as f64) // cast okay, worst case wil lose precision + } } impl Eq for PositiveF64 {} @@ -58,6 +67,34 @@ impl From for PositiveF64 { fn from(value: NonZeroU32) -> Self { Self(f64::from(u32::from(value))) } } +macro_rules! impl_op { + ($trait:ident, $op:ident, $expr:expr) => { + impl ops::$trait for PositiveF64 { + type Output = Self; + fn $op(self, rhs: Self) -> Self::Output { Self($expr(self.0, rhs.0)) } + } + + impl ops::$trait for &PositiveF64 { + type Output = PositiveF64; + fn $op(self, rhs: Self) -> Self::Output { PositiveF64($expr(self.0, rhs.0)) } + } + + impl ops::$trait<&PositiveF64> for PositiveF64 { + type Output = Self; + fn $op(self, rhs: &PositiveF64) -> Self::Output { Self($expr(self.0, rhs.0)) } + } + + impl ops::$trait for &PositiveF64 { + type Output = PositiveF64; + fn $op(self, rhs: PositiveF64) -> Self::Output { PositiveF64($expr(self.0, rhs.0)) } + } + }; +} + +impl_op!(Add, add, f64::add); +impl_op!(Mul, mul, f64::mul); +impl_op!(Div, div, f64::div); + pub struct NormalizedIterator { iter: I, /// Sum must be nonnegative, and may only be zero if `iter` is empty. From 77577049db51374d39677f54869d2a6310205dfb Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sat, 13 Jun 2026 23:15:37 +0000 Subject: [PATCH 06/13] concrete: directly use PositiveF64 in Policy::enumerate_pol By changing this one function and chasing compiler errors, I found I was able to (and required to) replace a *ton* of f64s with PositiveF64s. Amazingly, this did not require adding any new functionality to the PositiveF64 type, or any conversions. Instead, I was able to *delete* a ton of conversions, wrappers, unwrapping, etc. Now concrete.rs has literally no instances of 'f64' which is awesome. The next step is compiler.rs. --- src/policy/concrete.rs | 81 +++++++++++++++++----------------- src/primitives/positive_f64.rs | 4 ++ 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 494df9602..64f5f5ce8 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -348,8 +348,11 @@ impl Policy { let tap_tree = match policy { Self::Trivial => None, policy => { - let leaves = - policy.enumerate_leaves(1.0, max_leaves, Self::enumerate_pol_native); + let leaves = policy.enumerate_leaves( + PositiveF64::ONE, + max_leaves, + Self::enumerate_pol_native, + ); let n = leaves.len(); if n > max_leaves { return Err(CompilerError::TooManyTapleaves { n, max: max_leaves }); @@ -365,7 +368,7 @@ impl Policy { leaf_index: leaf_idx, }); } - leaf_compilations.push((PositiveF64(*prob), compilation)); + leaf_compilations.push((*prob, compilation)); } if !leaf_compilations.is_empty() { Some(with_huffman_tree::(leaf_compilations)) @@ -417,14 +420,11 @@ impl Policy { Self::Trivial => None, policy => { let leaf_compilations: Vec<_> = policy - .enumerate_policy_tree(1.0) + .enumerate_policy_tree(PositiveF64::ONE) .into_iter() .filter(|x| x.1 != Arc::new(Self::Unsatisfiable)) .map(|(prob, pol)| { - ( - PositiveF64(prob), - compiler::best_compilation(pol.as_ref()).unwrap(), - ) + (prob, compiler::best_compilation(pol.as_ref()).unwrap()) }) .collect(); @@ -502,20 +502,20 @@ impl Policy { /// disjunction over sub-policies output by it. The probability calculations are similar /// to [`Policy::tapleaf_probability_iter`]. #[cfg(feature = "compiler")] - fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_pol(&self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { match self { Self::Or(subs) => { let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); normalized_iter .zip(subs.iter()) - .map(|(odds, (_, pol))| (prob * f64::from(odds), pol.clone())) + .map(|(odds, (_, pol))| (prob * odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { - let total_odds = thresh.n(); + let total_odds = PositiveF64::n(thresh); thresh .iter() - .map(|pol| (prob / total_odds as f64, pol.clone())) + .map(|pol| (prob / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if !thresh.is_and() => generate_combination(thresh, prob), @@ -528,26 +528,26 @@ impl Policy { /// This ensures nested `Or` branches inside `And` nodes are decomposed into /// separate sub-policies, producing IF-free leaves for Taptree-native compilation. #[cfg(feature = "compiler")] - fn enumerate_pol_native(&self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_pol_native(&self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { match self { Self::Or(subs) => { let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); normalized_iter .zip(subs.iter()) - .map(|(odds, (_, pol))| (prob * f64::from(odds), pol.clone())) + .map(|(odds, (_, pol))| (prob * odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { - let total_odds = thresh.n(); + let total_odds = PositiveF64::n(thresh); thresh .iter() - .map(|pol| (prob / total_odds as f64, pol.clone())) + .map(|pol| (prob / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if !thresh.is_and() => generate_combination(thresh, prob), Self::And(subs) => { for (i, sub) in subs.iter().enumerate() { - let child_expanded = sub.enumerate_pol_native(1.0); + let child_expanded = sub.enumerate_pol_native(PositiveF64::ONE); if child_expanded.len() > 1 { let other: Vec<_> = subs .iter() @@ -578,7 +578,7 @@ impl Policy { /// set](`BTreeSet`) of `(prob, policy)` (ordered by probability) to maintain the list of /// enumerated sub-policies whose disjunction is isomorphic to initial policy (*invariant*). #[cfg(feature = "compiler")] - fn enumerate_policy_tree(self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_policy_tree(self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { self.enumerate_leaves(prob, MAX_COMPILATION_LEAVES, Self::enumerate_pol) } @@ -590,10 +590,10 @@ impl Policy { #[allow(clippy::type_complexity)] fn enumerate_leaves( self, - prob: f64, + prob: PositiveF64, max_leaves: usize, - expand_fn: fn(&Self, f64) -> Vec<(f64, Arc)>, - ) -> Vec<(f64, Arc)> { + expand_fn: fn(&Self, PositiveF64) -> Vec<(PositiveF64, Arc)>, + ) -> Vec<(PositiveF64, Arc)> { let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); // Store probability corresponding to policy in the enumerated tree. This is required since // owing to the current [policy element enumeration algorithm][`Policy::enumerate_pol`], @@ -602,8 +602,8 @@ impl Policy { let mut pol_prob_map = BTreeMap::, PositiveF64>::new(); let arc_self = Arc::new(self); - tapleaf_prob_vec.insert((Reverse(PositiveF64(prob)), Arc::clone(&arc_self))); - pol_prob_map.insert(Arc::clone(&arc_self), PositiveF64(prob)); + tapleaf_prob_vec.insert((Reverse(prob), Arc::clone(&arc_self))); + pol_prob_map.insert(Arc::clone(&arc_self), prob); // Since we know that policy enumeration *must* result in increase in total number of nodes, // we can maintain the length of the ordered set to check if the @@ -613,21 +613,21 @@ impl Policy { // store the variables let mut enum_len = tapleaf_prob_vec.len(); - let mut ret: Vec<(f64, Arc)> = vec![]; + let mut ret: Vec<(PositiveF64, Arc)> = vec![]; // Stopping condition: When NONE of the inputs can be further enumerated. 'outer: loop { //--- FIND a plausible node --- let mut prob: Reverse = Reverse(PositiveF64(0.0)); let mut curr_policy: Arc = Arc::new(Self::Unsatisfiable); - let mut curr_pol_replace_vec: Vec<(f64, Arc)> = vec![]; + let mut curr_pol_replace_vec: Vec<(PositiveF64, Arc)> = vec![]; let mut no_more_enum = false; // The nodes which can't be enumerated further are directly appended to ret and removed // from the ordered set. - let mut to_del: Vec<(f64, Arc)> = vec![]; + let mut to_del: Vec<(PositiveF64, Arc)> = vec![]; 'inner: for (i, (p, pol)) in tapleaf_prob_vec.iter().enumerate() { - curr_pol_replace_vec = expand_fn(pol, p.0 .0); + curr_pol_replace_vec = expand_fn(pol, p.0); enum_len += curr_pol_replace_vec.len() - 1; // A disjunctive node should have separated this into more nodes assert!(prev_len <= enum_len); @@ -644,14 +644,14 @@ impl Policy { // Either node is enumerable, or we have // Mark all non-enumerable nodes to remove, // if not returning value in the current iteration. - to_del.push((p.0 .0, Arc::clone(pol))); + to_del.push((p.0, Arc::clone(pol))); } } // --- Sanity Checks --- if enum_len > max_leaves || no_more_enum { for (p, pol) in tapleaf_prob_vec.into_iter() { - ret.push((p.0 .0, pol)); + ret.push((p.0, pol)); } break 'outer; } @@ -665,7 +665,7 @@ impl Policy { // OPTIMIZATION - Move marked nodes into final vector for (p, pol) in to_del { - assert!(tapleaf_prob_vec.remove(&(Reverse(PositiveF64(p)), pol.clone()))); + assert!(tapleaf_prob_vec.remove(&(Reverse(p), pol.clone()))); pol_prob_map.remove(&pol); ret.push((p, pol.clone())); } @@ -675,13 +675,12 @@ impl Policy { match pol_prob_map.get(&policy) { Some(prev_prob) => { assert!(tapleaf_prob_vec.remove(&(Reverse(*prev_prob), policy.clone()))); - tapleaf_prob_vec - .insert((Reverse(PositiveF64(prev_prob.0 + p)), policy.clone())); - pol_prob_map.insert(policy.clone(), PositiveF64(prev_prob.0 + p)); + tapleaf_prob_vec.insert((Reverse(prev_prob + p), policy.clone())); + pol_prob_map.insert(policy.clone(), prev_prob + p); } None => { - tapleaf_prob_vec.insert((Reverse(PositiveF64(p)), policy.clone())); - pol_prob_map.insert(policy.clone(), PositiveF64(p)); + tapleaf_prob_vec.insert((Reverse(p), policy.clone())); + pol_prob_map.insert(policy.clone(), p); } } } @@ -1181,12 +1180,12 @@ fn with_huffman_tree( #[cfg(feature = "compiler")] fn generate_combination( thresh: &Threshold>, 0>, - prob: f64, -) -> Vec<(f64, Arc>)> { + prob: PositiveF64, +) -> Vec<(PositiveF64, Arc>)> { debug_assert!(thresh.k() < thresh.n()); - let prob_over_n = prob / thresh.n() as f64; - let mut ret: Vec<(f64, Arc>)> = vec![]; + let prob_over_n = prob / PositiveF64::n(thresh); + let mut ret: Vec<(PositiveF64, Arc>)> = vec![]; for i in 0..thresh.n() { let thresh_less_1 = Threshold::from_iter( thresh.k(), @@ -1288,7 +1287,7 @@ mod compiler_tests { .collect(); let thresh = Threshold::new(2, policies).unwrap(); - let combinations = generate_combination(&thresh, 1.0); + let combinations = generate_combination(&thresh, PositiveF64::ONE); let comb_a: Vec> = vec![ policy_str!("pk(B)"), @@ -1315,7 +1314,7 @@ mod compiler_tests { .map(|sub_pol| { let expected_thresh = Threshold::from_iter(2, sub_pol.into_iter().map(Arc::new)).unwrap(); - (0.25, Arc::new(Policy::Thresh(expected_thresh))) + (PositiveF64::ONE_QUARTER, Arc::new(Policy::Thresh(expected_thresh))) }) .collect::>(); assert_eq!(combinations, expected_comb); diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs index b2bea9380..69013cf96 100644 --- a/src/primitives/positive_f64.rs +++ b/src/primitives/positive_f64.rs @@ -15,6 +15,10 @@ impl PositiveF64 { /// The constant one. pub const ONE: Self = Self(1.0); + /// Constant used in unit tsets + #[cfg(test)] + pub const ONE_QUARTER: Self = Self(0.25); + /// Takes an iterator over [`PositiveF64`] and produces a new iterator where /// each item is divided so that they all total to 1. /// From 3e7f594eea184e39667312e579253f66ea2fa62c Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 14:45:33 +0000 Subject: [PATCH 07/13] compiler: replace AstElemExt::terminal with explicit functions There is a ton of over-abstraction in AstElemeExt. Our original goal was to save on code by having one function, AstElemExt::type_check_common, which would match on a Terminal and do all the "fragment-specific" logic. But the consequence is that this needs to be really general to handle all the data that each fragment might need. In particular, disjunctions need "weights" which are provided by the user to indicate which branches are likely to be taken. No other kind of astelem needs extra data, but all the wrappers need infrastructure to carry this stuff around. By eliminating the terminal() function in place of direct constructors for each terminal type, we simplify the code and are able to also get rid of the type_check function. The next commits will eventually get rid of type_check_common, the branch_prob field of AstElemExt, a bunch of mutability and unwrap paths, and possibly more. --- src/policy/compiler.rs | 112 ++++++++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 64dd3be88..4f8d9c1df 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -12,6 +12,7 @@ use std::error; use sync::Arc; use crate::miniscript::context::SigType; +use crate::miniscript::limits::{MAX_PUBKEYS_IN_CHECKSIGADD, MAX_PUBKEYS_PER_MULTISIG}; use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; @@ -194,7 +195,7 @@ impl CompilerExtData { } } - fn multi(k: usize, _n: usize) -> Self { + fn multi(k: usize) -> Self { Self { branch_prob: None, sat_cost: 1.0 + 73.0 * k as f64, @@ -398,16 +399,6 @@ impl CompilerExtData { Self::type_check_common(fragment, get_child) } - /// Compute the type of a fragment. - fn type_check(fragment: &Terminal) -> Self - where - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let check_child = |sub, _n| Self::type_check(sub); - Self::type_check_common(fragment, check_child) - } - /// Compute the type of a fragment, given a function to look up /// the types of its children, if available and relevant for the /// given fragment @@ -422,7 +413,7 @@ impl CompilerExtData { Terminal::False => Self::FALSE, Terminal::PkK(..) => Self::pk_k::(), Terminal::PkH(..) | Terminal::RawPkH(..) => Self::pk_h::(), - Terminal::Multi(ref thresh) => Self::multi(thresh.k(), thresh.n()), + Terminal::Multi(ref thresh) => Self::multi(thresh.k()), Terminal::SortedMulti(ref thresh) => Self::sortedmulti(thresh.k(), thresh.n()), Terminal::MultiA(ref thresh) => Self::multi_a(thresh.k(), thresh.n()), Terminal::SortedMultiA(ref thresh) => Self::sortedmulti_a(thresh.k(), thresh.n()), @@ -508,8 +499,67 @@ impl AstElemExt { } impl AstElemExt { - fn terminal(ms: Miniscript) -> Self { - Self { comp_ext_data: CompilerExtData::type_check(ms.as_inner()), ms: Arc::new(ms) } + fn unsatisfiable() -> Self { + Self { ms: Arc::new(Miniscript::FALSE), comp_ext_data: CompilerExtData::FALSE } + } + + fn trivial() -> Self { + Self { ms: Arc::new(Miniscript::TRUE), comp_ext_data: CompilerExtData::TRUE } + } + + fn pk_h(key: Pk) -> Self { + Self { + ms: Arc::new(Miniscript::pk_h(key)), + comp_ext_data: CompilerExtData::pk_h::(), + } + } + + fn pk_k(key: Pk) -> Self { + Self { + ms: Arc::new(Miniscript::pk_k(key)), + comp_ext_data: CompilerExtData::pk_k::(), + } + } + + fn after(t: crate::AbsLockTime) -> Self { + Self { ms: Arc::new(Miniscript::after(t)), comp_ext_data: CompilerExtData::time() } + } + + fn older(t: crate::RelLockTime) -> Self { + Self { ms: Arc::new(Miniscript::older(t)), comp_ext_data: CompilerExtData::time() } + } + + fn sha256(h: Pk::Sha256) -> Self { + Self { ms: Arc::new(Miniscript::sha256(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn hash256(h: Pk::Hash256) -> Self { + Self { ms: Arc::new(Miniscript::hash256(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn ripemd160(h: Pk::Ripemd160) -> Self { + Self { ms: Arc::new(Miniscript::ripemd160(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn hash160(h: Pk::Hash160) -> Self { + Self { ms: Arc::new(Miniscript::hash160(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn multi(thresh: crate::Threshold) -> Self { + let k = thresh.k(); + Self { + ms: Arc::new(Miniscript::multi(thresh)), + comp_ext_data: CompilerExtData::multi(k), + } + } + + fn multi_a(thresh: crate::Threshold) -> Self { + let k = thresh.k(); + let n = thresh.n(); + Self { + ms: Arc::new(Miniscript::multi_a(thresh)), + comp_ext_data: CompilerExtData::multi_a(k, n), + } } fn binary(ast: Terminal, l: &Self, r: &Self) -> Result { @@ -793,30 +843,22 @@ where match *policy { Concrete::Unsatisfiable => { - insert_wrap!(AstElemExt::terminal(Miniscript::FALSE)); + insert_wrap!(AstElemExt::unsatisfiable()); } Concrete::Trivial => { - insert_wrap!(AstElemExt::terminal(Miniscript::TRUE)); + insert_wrap!(AstElemExt::trivial()); } Concrete::Key(ref pk) => { - insert_wrap!(AstElemExt::terminal(Miniscript::pk_h(pk.clone()))); - insert_wrap!(AstElemExt::terminal(Miniscript::pk_k(pk.clone()))); - } - Concrete::After(n) => insert_wrap!(AstElemExt::terminal(Miniscript::after(n))), - Concrete::Older(n) => insert_wrap!(AstElemExt::terminal(Miniscript::older(n))), - Concrete::Sha256(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::sha256(hash.clone()))) + insert_wrap!(AstElemExt::pk_h(pk.clone())); + insert_wrap!(AstElemExt::pk_k(pk.clone())); } + Concrete::After(n) => insert_wrap!(AstElemExt::after(n)), + Concrete::Older(n) => insert_wrap!(AstElemExt::older(n)), + Concrete::Sha256(ref hash) => insert_wrap!(AstElemExt::sha256(hash.clone())), // Satisfaction-cost + script-cost - Concrete::Hash256(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::hash256(hash.clone()))) - } - Concrete::Ripemd160(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::ripemd160(hash.clone()))) - } - Concrete::Hash160(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::hash160(hash.clone()))) - } + Concrete::Hash256(ref hash) => insert_wrap!(AstElemExt::hash256(hash.clone())), + Concrete::Ripemd160(ref hash) => insert_wrap!(AstElemExt::ripemd160(hash.clone())), + Concrete::Hash160(ref hash) => insert_wrap!(AstElemExt::hash160(hash.clone())), Concrete::And(ref subs) => { assert_eq!(subs.len(), 2, "and takes 2 args"); let mut left = @@ -835,7 +877,7 @@ where let mut zero_comp = BTreeMap::new(); zero_comp.insert( CompilationKey::from_type(Type::FALSE, ExtData::FALSE.has_free_verify, dissat_prob), - AstElemExt::terminal(Miniscript::FALSE), + AstElemExt::unsatisfiable(), ); compile_tern!(&mut left, &mut q_zero_right, &mut zero_comp, [1.0, 0.0]); compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [1.0, 0.0]); @@ -1022,12 +1064,12 @@ where match Ctx::sig_type() { SigType::Schnorr => { if let Ok(pk_thresh) = pk_thresh.set_maximum() { - insert_wrap!(AstElemExt::terminal(Miniscript::multi_a(pk_thresh))) + insert_wrap!(AstElemExt::multi_a(pk_thresh)) } } SigType::Ecdsa => { if let Ok(pk_thresh) = pk_thresh.set_maximum() { - insert_wrap!(AstElemExt::terminal(Miniscript::multi(pk_thresh))) + insert_wrap!(AstElemExt::multi(pk_thresh)) } } } From 6d94d82f7d31a5241b16a6af7dc5ed677555e061 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 15:00:32 +0000 Subject: [PATCH 08/13] compiler: replace conjuctions in AstElemExt::type_check with explicit functions Conjunctions don't need to set 'branch_prob', but we do set them, leading to confusing logic. By getting rid of these, we can make the conjunction logic simpler. Relatedly, we add an AstElemExt::and_n function rather than trying to force AstElem::and_or to produce an and_n. They're fundamentally different even though they compile to the same sort of script; and_n is a conjunction, can not be dissatisfied, and has no weights associated with it. By separating the two concepts, we eliminate a place where we were setting probabilities to 0.0, which is incorrect. --- src/policy/compiler.rs | 105 ++++++++++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 22 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 4f8d9c1df..c02b1ba4e 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -13,7 +13,7 @@ use sync::Arc; use crate::miniscript::context::SigType; use crate::miniscript::limits::{MAX_PUBKEYS_IN_CHECKSIGADD, MAX_PUBKEYS_PER_MULTISIG}; -use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; +use crate::miniscript::types::{self, ErrorKind, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; use crate::prelude::*; @@ -280,6 +280,10 @@ impl CompilerExtData { Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } + fn and_n(left: Self, right: Self) -> Self { + Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } + } + fn or_b(l: Self, r: Self) -> Self { let lprob = l .branch_prob @@ -353,7 +357,7 @@ impl CompilerExtData { fn and_or(a: Self, b: Self, c: Self) -> Self { let aprob = a.branch_prob.expect("andor, a prob must be set"); let bprob = b.branch_prob.expect("andor, b prob must be set"); - let cprob = c.branch_prob.expect("andor, c prob must be set"); + let cprob = c.branch_prob.unwrap_or(0.0); let adis = a .dissat_cost @@ -562,6 +566,48 @@ impl AstElemExt { } } + /// Helper functions to compose two Miniscript fragments, where we assume + /// by construction that all validation parameters are upheld. + fn compose_typeck_only( + term: Terminal, + ) -> Result>, types::Error> { + let ty = types::Type::type_check(&term)?; + let ext = types::ExtData::type_check(&term); + Ok(Arc::new(Miniscript::from_components_unchecked(term, ty, ext))) + } + + fn and_b(left: &Self, right: &Self) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndB( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::and_b(left.comp_ext_data, right.comp_ext_data), + }) + } + + fn and_v(left: &Self, right: &Self) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndV( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::and_v(left.comp_ext_data, right.comp_ext_data), + }) + } + + /// and_n(a,b) == andor(a,b,0) is a conjunction of a and b + fn and_n(left: &Self, right: &Self) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndOr( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + Arc::new(Miniscript::FALSE), + ))?, + comp_ext_data: CompilerExtData::and_n(left.comp_ext_data, right.comp_ext_data), + }) + } + fn binary(ast: Terminal, l: &Self, r: &Self) -> Result { let lookup_ext = |n| match n { 0 => l.comp_ext_data, @@ -861,26 +907,41 @@ where Concrete::Hash160(ref hash) => insert_wrap!(AstElemExt::hash160(hash.clone())), Concrete::And(ref subs) => { assert_eq!(subs.len(), 2, "and takes 2 args"); - let mut left = - best_compilations(policy_cache, subs[0].as_ref(), sat_prob, dissat_prob)?; - let mut right = - best_compilations(policy_cache, subs[1].as_ref(), sat_prob, dissat_prob)?; - let mut q_zero_right = - best_compilations(policy_cache, subs[1].as_ref(), sat_prob, None)?; - let mut q_zero_left = - best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?; - - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndV); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndV); - let mut zero_comp = BTreeMap::new(); - zero_comp.insert( - CompilationKey::from_type(Type::FALSE, ExtData::FALSE.has_free_verify, dissat_prob), - AstElemExt::unsatisfiable(), - ); - compile_tern!(&mut left, &mut q_zero_right, &mut zero_comp, [1.0, 0.0]); - compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [1.0, 0.0]); + let left = best_compilations(policy_cache, subs[0].as_ref(), sat_prob, dissat_prob)?; + let right = best_compilations(policy_cache, subs[1].as_ref(), sat_prob, dissat_prob)?; + let q_zero_right = best_compilations(policy_cache, subs[1].as_ref(), sat_prob, None)?; + let q_zero_left = best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?; + + let mut insert_binary = |left: &BTreeMap<_, _>, + right: &BTreeMap<_, _>, + combinator: fn(&_, &_) -> Result<_, _>| + -> Result<(), CompilerError> { + for l in left.values() { + for r in right.values() { + if let Ok(new_ext) = combinator(l, r) { + insert_best_wrapped( + policy_cache, + policy, + &mut ret, + new_ext, + sat_prob, + dissat_prob, + )?; + } + } + } + Ok(()) + }; + insert_binary(&left, &right, AstElemExt::and_b)?; + // Do a separate loop with 'l' and 'r' swapped; we could combine the loops, + // but this would sometimes result in compiling e.g. and(pk(A),pk(B)) into + // an and with A and B swapped, which is surprising to the user since the + // cost is the same with or without the swap. + insert_binary(&right, &left, AstElemExt::and_b)?; + insert_binary(&left, &right, AstElemExt::and_v)?; + insert_binary(&right, &left, AstElemExt::and_v)?; + insert_binary(&left, &q_zero_right, AstElemExt::and_n)?; + insert_binary(&right, &q_zero_left, AstElemExt::and_n)?; } Concrete::Or(ref subs) => { let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; From af6e5bd4dd641b83a87dfab6c6fbd522df427d14 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 18:06:38 +0000 Subject: [PATCH 09/13] compiler: replace disjunctions Had to move the logic into a separate function to avoid stack overflows, because we are recursing very large/complex functions. Other than that I tried to preserve the existing logic (and ran some local fuzztests quite extensively to verify). I removed all the unused code from AstElemExt::type_check, since it would no longer compile -- I changed the disjunction code to take the left/right weights as parameters rather than storing them in the AstElemExt object. The next commit will cause even more code to become dead; where possible, I have tried to prune dead code in separate commits to make the diffs easier to read. The goal of these commits is to remove the `branch_prob` field entirely, which will then let us make AstElemExt immutable, which will unlock some further refactorings. But until I'm done this, I have to just shuffle around the existing logic and the result is pretty big/ugly. --- src/policy/compiler.rs | 435 +++++++++++++++++------------------------ 1 file changed, 181 insertions(+), 254 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index c02b1ba4e..0adfd95ee 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,6 +5,7 @@ //! Optimizing compiler from concrete policies to Miniscript //! +use core::num::NonZeroU32; use core::{f64, fmt, mem}; #[cfg(feature = "std")] use std::error; @@ -89,6 +90,124 @@ impl fmt::Display for CompilerError { } } +fn best_compilations_or( + ret: &mut BTreeMap>, + policy_cache: &mut PolicyCache, + policy: &Concrete, + subs: &[(NonZeroU32, Arc>)], + sat_prob: f64, + dissat_prob: Option, +) -> Result<(), CompilerError> { + macro_rules! compile_tern { + ($a:expr, $b:expr, $c: expr, $w: expr) => { + compile_tern(policy_cache, policy, ret, $a, $b, $c, $w, sat_prob, dissat_prob)? + }; + } + + let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; + let lw = u32::from(subs[0].0) as f64 / total; + let rw = u32::from(subs[1].0) as f64 / total; + + //and-or + if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { + let mut a1 = best_compilations( + policy_cache, + x[0].as_ref(), + lw * sat_prob, + Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + )?; + let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; + + let mut b1 = best_compilations( + policy_cache, + x[1].as_ref(), + lw * sat_prob, + Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + )?; + let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; + + let mut c = + best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, dissat_prob)?; + + compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]); + compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]); + }; + if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { + let mut a1 = best_compilations( + policy_cache, + x[0].as_ref(), + rw * sat_prob, + Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + )?; + let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; + + let mut b1 = best_compilations( + policy_cache, + x[1].as_ref(), + rw * sat_prob, + Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + )?; + let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; + + let mut c = + best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, dissat_prob)?; + + compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]); + compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]); + }; + + let dissat_probs = |w: f64| -> Vec> { + vec![ + Some(dissat_prob.unwrap_or(0 as f64) + w * sat_prob), + Some(w * sat_prob), + dissat_prob, + None, + ] + }; + + let mut l_comp = vec![]; + let mut r_comp = vec![]; + + for dissat_prob in dissat_probs(rw).iter() { + let l = best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, *dissat_prob)?; + l_comp.push(l); + } + + for dissat_prob in dissat_probs(lw).iter() { + let r = best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, *dissat_prob)?; + r_comp.push(r); + } + + let mut insert_binary = |left: &BTreeMap<_, _>, + right: &BTreeMap<_, _>, + lw: f64, + rw: f64, + combinator: fn(&_, &_, _, _) -> Result<_, _>| + -> Result<(), CompilerError> { + for l in left.values() { + for r in right.values() { + if let Ok(new_ext) = combinator(l, r, lw, rw) { + insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; + } + } + } + Ok(()) + }; + + insert_binary(&l_comp[0], &r_comp[0], lw, rw, AstElemExt::or_b)?; + insert_binary(&r_comp[0], &l_comp[0], rw, lw, AstElemExt::or_b)?; + insert_binary(&l_comp[0], &r_comp[2], lw, rw, AstElemExt::or_d)?; + insert_binary(&r_comp[0], &l_comp[2], rw, lw, AstElemExt::or_d)?; + insert_binary(&l_comp[1], &r_comp[3], lw, rw, AstElemExt::or_c)?; + insert_binary(&r_comp[1], &l_comp[3], rw, lw, AstElemExt::or_c)?; + insert_binary(&l_comp[2], &r_comp[3], lw, rw, AstElemExt::or_i)?; + insert_binary(&r_comp[3], &l_comp[2], rw, lw, AstElemExt::or_i)?; + insert_binary(&l_comp[3], &r_comp[2], lw, rw, AstElemExt::or_i)?; + insert_binary(&r_comp[2], &l_comp[3], rw, lw, AstElemExt::or_i)?; + + Ok(()) +} + #[cfg(feature = "std")] impl error::Error for CompilerError { fn cause(&self) -> Option<&dyn error::Error> { @@ -203,14 +322,6 @@ impl CompilerExtData { } } - fn sortedmulti(k: usize, _n: usize) -> Self { - Self { - branch_prob: None, - sat_cost: 1.0 + 73.0 * k as f64, - dissat_cost: Some(1.0 * (k + 1) as f64), - } - } - fn multi_a(k: usize, n: usize) -> Self { Self { branch_prob: None, @@ -219,8 +330,6 @@ impl CompilerExtData { } } - fn sortedmulti_a(k: usize, n: usize) -> Self { Self::multi_a(k, n) } - fn hash() -> Self { Self { branch_prob: None, sat_cost: 33.0, dissat_cost: Some(33.0) } } fn time() -> Self { Self { branch_prob: None, sat_cost: 0.0, dissat_cost: None } } @@ -284,13 +393,7 @@ impl CompilerExtData { Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } - fn or_b(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_b(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { branch_prob: None, sat_cost: lprob * (l.sat_cost + r.dissat_cost.unwrap()) @@ -299,13 +402,7 @@ impl CompilerExtData { } } - fn or_d(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_d(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { branch_prob: None, sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), @@ -313,13 +410,7 @@ impl CompilerExtData { } } - fn or_c(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_c(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { branch_prob: None, sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), @@ -328,13 +419,7 @@ impl CompilerExtData { } #[allow(clippy::manual_map)] // Complex if/let is better as is. - fn or_i(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_i(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { branch_prob: None, sat_cost: lprob * (2.0 + l.sat_cost) + rprob * (1.0 + r.sat_cost), @@ -413,57 +498,6 @@ impl CompilerExtData { Ctx: ScriptContext, { match *fragment { - Terminal::True => Self::TRUE, - Terminal::False => Self::FALSE, - Terminal::PkK(..) => Self::pk_k::(), - Terminal::PkH(..) | Terminal::RawPkH(..) => Self::pk_h::(), - Terminal::Multi(ref thresh) => Self::multi(thresh.k()), - Terminal::SortedMulti(ref thresh) => Self::sortedmulti(thresh.k(), thresh.n()), - Terminal::MultiA(ref thresh) => Self::multi_a(thresh.k(), thresh.n()), - Terminal::SortedMultiA(ref thresh) => Self::sortedmulti_a(thresh.k(), thresh.n()), - Terminal::After(_) => Self::time(), - Terminal::Older(_) => Self::time(), - Terminal::Sha256(..) => Self::hash(), - Terminal::Hash256(..) => Self::hash(), - Terminal::Ripemd160(..) => Self::hash(), - Terminal::Hash160(..) => Self::hash(), - Terminal::Alt(ref sub) => Self::cast_alt(get_child(&sub.node, 0)), - Terminal::Swap(ref sub) => Self::cast_swap(get_child(&sub.node, 0)), - Terminal::Check(ref sub) => Self::cast_check(get_child(&sub.node, 0)), - Terminal::DupIf(ref sub) => Self::cast_dupif(get_child(&sub.node, 0)), - Terminal::Verify(ref sub) => Self::cast_verify(get_child(&sub.node, 0)), - Terminal::NonZero(ref sub) => Self::cast_nonzero(get_child(&sub.node, 0)), - Terminal::ZeroNotEqual(ref sub) => Self::cast_zeronotequal(get_child(&sub.node, 0)), - Terminal::AndB(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::and_b(ltype, rtype) - } - Terminal::AndV(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::and_v(ltype, rtype) - } - Terminal::OrB(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_b(ltype, rtype) - } - Terminal::OrD(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_d(ltype, rtype) - } - Terminal::OrC(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_c(ltype, rtype) - } - Terminal::OrI(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_i(ltype, rtype) - } Terminal::AndOr(ref a, ref b, ref c) => { let atype = get_child(&a.node, 0); let btype = get_child(&b.node, 1); @@ -473,6 +507,7 @@ impl CompilerExtData { Terminal::Thresh(ref thresh) => { Self::threshold(thresh.k(), thresh.n(), |n| get_child(&thresh.data()[n].node, n)) } + _ => unreachable!(), } } } @@ -608,20 +643,63 @@ impl AstElemExt { }) } - fn binary(ast: Terminal, l: &Self, r: &Self) -> Result { - let lookup_ext = |n| match n { - 0 => l.comp_ext_data, - 1 => r.comp_ext_data, - _ => unreachable!(), - }; - //Types and ExtData are already cached and stored in children. So, we can - //type_check without cache. For Compiler extra data, we supply a cache. - let ty = types::Type::type_check(&ast)?; - let ext = types::ExtData::type_check(&ast); - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext); + fn or_b(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { Ok(Self { - ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), - comp_ext_data, + ms: Self::compose_typeck_only(Terminal::OrB( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_b( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_d(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrD( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_d( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_c(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrC( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_c( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_i(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrI( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_i( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), }) } @@ -876,16 +954,6 @@ where insert_best_wrapped(policy_cache, policy, &mut ret, $x, sat_prob, dissat_prob)? }; } - macro_rules! compile_binary { - ($l:expr, $r:expr, $w: expr, $f: expr) => { - compile_binary(policy_cache, policy, &mut ret, $l, $r, $w, sat_prob, dissat_prob, $f)? - }; - } - macro_rules! compile_tern { - ($a:expr, $b:expr, $c: expr, $w: expr) => { - compile_tern(policy_cache, policy, &mut ret, $a, $b, $c, $w, sat_prob, dissat_prob)? - }; - } match *policy { Concrete::Unsatisfiable => { @@ -944,113 +1012,7 @@ where insert_binary(&right, &q_zero_left, AstElemExt::and_n)?; } Concrete::Or(ref subs) => { - let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; - let lw = u32::from(subs[0].0) as f64 / total; - let rw = u32::from(subs[1].0) as f64 / total; - - //and-or - if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( - policy_cache, - x[0].as_ref(), - lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), - )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; - - let mut b1 = best_compilations( - policy_cache, - x[1].as_ref(), - lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), - )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; - - let mut c = best_compilations( - policy_cache, - subs[1].1.as_ref(), - rw * sat_prob, - dissat_prob, - )?; - - compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]); - compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]); - }; - if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( - policy_cache, - x[0].as_ref(), - rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), - )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; - - let mut b1 = best_compilations( - policy_cache, - x[1].as_ref(), - rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), - )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; - - let mut c = best_compilations( - policy_cache, - subs[0].1.as_ref(), - lw * sat_prob, - dissat_prob, - )?; - - compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]); - compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]); - }; - - let dissat_probs = |w: f64| -> Vec> { - vec![ - Some(dissat_prob.unwrap_or(0 as f64) + w * sat_prob), - Some(w * sat_prob), - dissat_prob, - None, - ] - }; - - let mut l_comp = vec![]; - let mut r_comp = vec![]; - - for dissat_prob in dissat_probs(rw).iter() { - let l = best_compilations( - policy_cache, - subs[0].1.as_ref(), - lw * sat_prob, - *dissat_prob, - )?; - l_comp.push(l); - } - - for dissat_prob in dissat_probs(lw).iter() { - let r = best_compilations( - policy_cache, - subs[1].1.as_ref(), - rw * sat_prob, - *dissat_prob, - )?; - r_comp.push(r); - } - - // or(sha256, pk) - compile_binary!(&mut l_comp[0], &mut r_comp[0], [lw, rw], Terminal::OrB); - compile_binary!(&mut r_comp[0], &mut l_comp[0], [rw, lw], Terminal::OrB); - - compile_binary!(&mut l_comp[0], &mut r_comp[2], [lw, rw], Terminal::OrD); - compile_binary!(&mut r_comp[0], &mut l_comp[2], [rw, lw], Terminal::OrD); - - compile_binary!(&mut l_comp[1], &mut r_comp[3], [lw, rw], Terminal::OrC); - compile_binary!(&mut r_comp[1], &mut l_comp[3], [rw, lw], Terminal::OrC); - - compile_binary!(&mut l_comp[2], &mut r_comp[3], [lw, rw], Terminal::OrI); - compile_binary!(&mut r_comp[2], &mut l_comp[3], [rw, lw], Terminal::OrI); - - compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI); - compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI); + best_compilations_or(&mut ret, policy_cache, policy, subs, sat_prob, dissat_prob)?; } Concrete::Thresh(ref thresh) => { let k = thresh.k(); @@ -1163,41 +1125,6 @@ where } } -/// Helper function to compile different types of binary fragments. -/// `sat_prob` and `dissat_prob` represent the sat and dissat probabilities of -/// root or. `weights` represent the odds for taking each sub branch -#[allow(clippy::too_many_arguments)] -fn compile_binary( - policy_cache: &mut PolicyCache, - policy: &Concrete, - ret: &mut BTreeMap>, - left_comp: &mut BTreeMap>, - right_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, - bin_func: F, -) -> Result<(), CompilerError> -where - Pk: MiniscriptKey, - Ctx: ScriptContext, - F: Fn(Arc>, Arc>) -> Terminal, -{ - for l in left_comp.values_mut() { - let lref = Arc::clone(&l.ms); - for r in right_comp.values_mut() { - let rref = Arc::clone(&r.ms); - let ast = bin_func(Arc::clone(&lref), Arc::clone(&rref)); - l.comp_ext_data.branch_prob = Some(weights[0]); - r.comp_ext_data.branch_prob = Some(weights[1]); - if let Ok(new_ext) = AstElemExt::binary(ast, l, r) { - insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; - } - } - } - Ok(()) -} - /// Helper function to compile different order of and_or fragments. /// `sat_prob` and `dissat_prob` represent the sat and dissat probabilities of /// root and_or node. `weights` represent the odds for taking each sub branch From 4f9a0539d8d56a2f8e08411211cf65e2c6749e7e Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 22:58:22 +0000 Subject: [PATCH 10/13] compiler: add AstElemExt::and_or function Update the CompExtData::and_or function to take weights, as for the other disjunctions, eliminating the last place where we read the branch_prob field. This totally eliminates the type_check method, and stops using the branch_prob field of AstElemExt. The next commit will delete a ton of unused stuff, which I didn't do here to minimize the amount of noise in this commit. Then I will start converting to use PositiveF64 rather than f64 everywhere. --- src/policy/compiler.rs | 110 ++++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 41 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 0adfd95ee..b46647ae1 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,6 +5,9 @@ //! Optimizing compiler from concrete policies to Miniscript //! +#![allow(dead_code)] // will be removed in next commit +#![allow(unused_variables)] // will be removed in next commit + use core::num::NonZeroU32; use core::{f64, fmt, mem}; #[cfg(feature = "std")] @@ -98,62 +101,80 @@ fn best_compilations_or( sat_prob: f64, dissat_prob: Option, ) -> Result<(), CompilerError> { - macro_rules! compile_tern { - ($a:expr, $b:expr, $c: expr, $w: expr) => { - compile_tern(policy_cache, policy, ret, $a, $b, $c, $w, sat_prob, dissat_prob)? - }; - } - let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; let lw = u32::from(subs[0].0) as f64 / total; let rw = u32::from(subs[1].0) as f64 / total; //and-or + let mut insert_ternary = |policy_cache: &mut _, + a: &BTreeMap<_, _>, + b: &BTreeMap<_, _>, + c: &BTreeMap<_, _>, + lw: f64, + rw: f64| + -> Result<(), CompilerError> { + for a in a.values() { + for b in b.values() { + for c in c.values() { + if let Ok(new_ext) = AstElemExt::and_or(a, b, c, lw, rw) { + insert_best_wrapped( + policy_cache, + policy, + ret, + new_ext, + sat_prob, + dissat_prob, + )?; + } + } + } + } + Ok(()) + }; + if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( + let a1 = best_compilations( policy_cache, x[0].as_ref(), lw * sat_prob, Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; + let a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; - let mut b1 = best_compilations( + let b1 = best_compilations( policy_cache, x[1].as_ref(), lw * sat_prob, Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; + let b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; - let mut c = - best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, dissat_prob)?; + let c = best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, dissat_prob)?; - compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]); - compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]); + insert_ternary(policy_cache, &a1, &b2, &c, lw, rw)?; + insert_ternary(policy_cache, &b1, &a2, &c, lw, rw)?; }; if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( + let a1 = best_compilations( policy_cache, x[0].as_ref(), rw * sat_prob, Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; + let a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; - let mut b1 = best_compilations( + let b1 = best_compilations( policy_cache, x[1].as_ref(), rw * sat_prob, Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; + let b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; - let mut c = - best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, dissat_prob)?; + let c = best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, dissat_prob)?; - compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]); - compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]); + insert_ternary(policy_cache, &a1, &b2, &c, rw, lw)?; + insert_ternary(policy_cache, &b1, &a2, &c, rw, lw)?; }; let dissat_probs = |w: f64| -> Vec> { @@ -439,18 +460,13 @@ impl CompilerExtData { } } - fn and_or(a: Self, b: Self, c: Self) -> Self { - let aprob = a.branch_prob.expect("andor, a prob must be set"); - let bprob = b.branch_prob.expect("andor, b prob must be set"); - let cprob = c.branch_prob.unwrap_or(0.0); - + fn and_or(a: Self, b: Self, c: Self, lprob: f64, rprob: f64) -> Self { let adis = a .dissat_cost .expect("BUG: and_or first arg(a) must be dissatisfiable"); - debug_assert_eq!(aprob, bprob); //A and B must have same branch prob. Self { branch_prob: None, - sat_cost: aprob * (a.sat_cost + b.sat_cost) + cprob * (adis + c.sat_cost), + sat_cost: lprob * (a.sat_cost + b.sat_cost) + rprob * (adis + c.sat_cost), dissat_cost: c.dissat_cost.map(|cdis| adis + cdis), } } @@ -497,18 +513,7 @@ impl CompilerExtData { Pk: MiniscriptKey, Ctx: ScriptContext, { - match *fragment { - Terminal::AndOr(ref a, ref b, ref c) => { - let atype = get_child(&a.node, 0); - let btype = get_child(&b.node, 1); - let ctype = get_child(&c.node, 2); - Self::and_or(atype, btype, ctype) - } - Terminal::Thresh(ref thresh) => { - Self::threshold(thresh.k(), thresh.n(), |n| get_child(&thresh.data()[n].node, n)) - } - _ => unreachable!(), - } + unreachable!() } } @@ -643,6 +648,29 @@ impl AstElemExt { }) } + fn and_or( + a: &Self, + b: &Self, + c: &Self, + l_weight: f64, + r_weight: f64, + ) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndOr( + Arc::clone(&a.ms), + Arc::clone(&b.ms), + Arc::clone(&c.ms), + ))?, + comp_ext_data: CompilerExtData::and_or( + a.comp_ext_data, + b.comp_ext_data, + c.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + fn or_b(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::OrB( From 0c78c9b42266492b9cb56f09e4135a359ba0843b Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 23:08:59 +0000 Subject: [PATCH 11/13] compiler: delete a ton of unused code --- src/policy/compiler.rs | 149 +++++------------------------------------ 1 file changed, 17 insertions(+), 132 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index b46647ae1..ff97bc7de 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,9 +5,6 @@ //! Optimizing compiler from concrete policies to Miniscript //! -#![allow(dead_code)] // will be removed in next commit -#![allow(unused_variables)] // will be removed in next commit - use core::num::NonZeroU32; use core::{f64, fmt, mem}; #[cfg(feature = "std")] @@ -290,10 +287,6 @@ impl CompilationKey { #[derive(Copy, Clone, Debug)] struct CompilerExtData { - /// If this node is the direct child of a disjunction, this field must - /// have the probability of its branch being taken. Otherwise it is ignored. - /// All functions initialize it to `None`. - branch_prob: Option, /// The number of bytes needed to satisfy the fragment in segwit format /// (total length of all witness pushes, plus their own length prefixes) sat_cost: f64, @@ -304,13 +297,12 @@ struct CompilerExtData { } impl CompilerExtData { - const TRUE: Self = Self { branch_prob: None, sat_cost: 0.0, dissat_cost: None }; + const TRUE: Self = Self { sat_cost: 0.0, dissat_cost: None }; - const FALSE: Self = Self { branch_prob: None, sat_cost: f64::MAX, dissat_cost: Some(0.0) }; + const FALSE: Self = Self { sat_cost: f64::MAX, dissat_cost: Some(0.0) }; fn pk_k() -> Self { Self { - branch_prob: None, sat_cost: match Ctx::sig_type() { SigType::Ecdsa => 73.0, SigType::Schnorr => 1.0 /* */ + 64.0 /* sig */ + 1.0, /* */ @@ -321,7 +313,6 @@ impl CompilerExtData { fn pk_h() -> Self { Self { - branch_prob: None, sat_cost: match Ctx::sig_type() { SigType::Ecdsa => 73.0 + 34.0, SigType::Schnorr => 66.0 + 33.0, @@ -336,68 +327,46 @@ impl CompilerExtData { } fn multi(k: usize) -> Self { - Self { - branch_prob: None, - sat_cost: 1.0 + 73.0 * k as f64, - dissat_cost: Some(1.0 * (k + 1) as f64), - } + Self { sat_cost: 1.0 + 73.0 * k as f64, dissat_cost: Some(1.0 * (k + 1) as f64) } } fn multi_a(k: usize, n: usize) -> Self { Self { - branch_prob: None, sat_cost: 66.0 * k as f64 + (n - k) as f64, dissat_cost: Some(n as f64), /* ... := 0x00 ... 0x00 (n times) */ } } - fn hash() -> Self { Self { branch_prob: None, sat_cost: 33.0, dissat_cost: Some(33.0) } } + fn hash() -> Self { Self { sat_cost: 33.0, dissat_cost: Some(33.0) } } - fn time() -> Self { Self { branch_prob: None, sat_cost: 0.0, dissat_cost: None } } + fn time() -> Self { Self { sat_cost: 0.0, dissat_cost: None } } - fn cast_alt(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } + fn cast_alt(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_swap(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } + fn cast_swap(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_check(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } + fn cast_check(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_dupif(self) -> Self { - Self { branch_prob: None, sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } - } + fn cast_dupif(self) -> Self { Self { sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } } - fn cast_verify(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: None } - } + fn cast_verify(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: None } } - fn cast_nonzero(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: Some(1.0) } - } + fn cast_nonzero(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: Some(1.0) } } fn cast_zeronotequal(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } + Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_true(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: None } - } + fn cast_true(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: None } } fn cast_unlikely(self) -> Self { - Self { branch_prob: None, sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } + Self { sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } } - fn cast_likely(self) -> Self { - Self { branch_prob: None, sat_cost: 1.0 + self.sat_cost, dissat_cost: Some(2.0) } - } + fn cast_likely(self) -> Self { Self { sat_cost: 1.0 + self.sat_cost, dissat_cost: Some(2.0) } } fn and_b(left: Self, right: Self) -> Self { Self { - branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: match (left.dissat_cost, right.dissat_cost) { (Some(l), Some(r)) => Some(l + r), @@ -407,16 +376,15 @@ impl CompilerExtData { } fn and_v(left: Self, right: Self) -> Self { - Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } + Self { sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } fn and_n(left: Self, right: Self) -> Self { - Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } + Self { sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } fn or_b(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { - branch_prob: None, sat_cost: lprob * (l.sat_cost + r.dissat_cost.unwrap()) + rprob * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: Some(l.dissat_cost.unwrap() + r.dissat_cost.unwrap()), @@ -425,7 +393,6 @@ impl CompilerExtData { fn or_d(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { - branch_prob: None, sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: r.dissat_cost.map(|rd| l.dissat_cost.unwrap() + rd), } @@ -433,7 +400,6 @@ impl CompilerExtData { fn or_c(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { - branch_prob: None, sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: None, } @@ -442,7 +408,6 @@ impl CompilerExtData { #[allow(clippy::manual_map)] // Complex if/let is better as is. fn or_i(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { Self { - branch_prob: None, sat_cost: lprob * (2.0 + l.sat_cost) + rprob * (1.0 + r.sat_cost), dissat_cost: if let (Some(ldis), Some(rdis)) = (l.dissat_cost, r.dissat_cost) { if (2.0 + ldis) > (1.0 + rdis) { @@ -465,7 +430,6 @@ impl CompilerExtData { .dissat_cost .expect("BUG: and_or first arg(a) must be dissatisfiable"); Self { - branch_prob: None, sat_cost: lprob * (a.sat_cost + b.sat_cost) + rprob * (adis + c.sat_cost), dissat_cost: c.dissat_cost.map(|cdis| adis + cdis), } @@ -484,39 +448,12 @@ impl CompilerExtData { dissat_cost += sub.dissat_cost.unwrap(); } Self { - branch_prob: None, sat_cost: sat_cost * k_over_n + dissat_cost * (1.0 - k_over_n), dissat_cost: Some(dissat_cost), } } } -impl CompilerExtData { - /// Compute the type of a fragment, given a function to look up - /// the types of its children. - fn type_check_with_child(fragment: &Terminal, child: C) -> Self - where - C: Fn(usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let get_child = |_sub, n| child(n); - Self::type_check_common(fragment, get_child) - } - - /// Compute the type of a fragment, given a function to look up - /// the types of its children, if available and relevant for the - /// given fragment - fn type_check_common<'a, Pk, Ctx, C>(fragment: &'a Terminal, get_child: C) -> Self - where - C: Fn(&'a Terminal, usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - unreachable!() - } -} - /// Miniscript AST fragment with additional data needed by the compiler #[derive(Clone, Debug)] struct AstElemExt { @@ -730,24 +667,6 @@ impl AstElemExt { ), }) } - - fn ternary(ast: Terminal, a: &Self, b: &Self, c: &Self) -> Result { - let lookup_ext = |n| match n { - 0 => a.comp_ext_data, - 1 => b.comp_ext_data, - 2 => c.comp_ext_data, - _ => unreachable!(), - }; - //Types and ExtData are already cached and stored in children. So, we can - //type_check without cache. For Compiler extra data, we supply a cache. - let ty = types::Type::type_check(&ast)?; - let ext = types::ExtData::type_check(&ast); - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext); - Ok(Self { - ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), - comp_ext_data, - }) - } } /// Different types of casts possible for each node. @@ -1153,40 +1072,6 @@ where } } -/// Helper function to compile different order of and_or fragments. -/// `sat_prob` and `dissat_prob` represent the sat and dissat probabilities of -/// root and_or node. `weights` represent the odds for taking each sub branch -#[allow(clippy::too_many_arguments)] -fn compile_tern( - policy_cache: &mut PolicyCache, - policy: &Concrete, - ret: &mut BTreeMap>, - a_comp: &mut BTreeMap>, - b_comp: &mut BTreeMap>, - c_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, -) -> Result<(), CompilerError> { - for a in a_comp.values_mut() { - let aref = Arc::clone(&a.ms); - for b in b_comp.values_mut() { - let bref = Arc::clone(&b.ms); - for c in c_comp.values_mut() { - let cref = Arc::clone(&c.ms); - let ast = Terminal::AndOr(Arc::clone(&aref), Arc::clone(&bref), Arc::clone(&cref)); - a.comp_ext_data.branch_prob = Some(weights[0]); - b.comp_ext_data.branch_prob = Some(weights[0]); - c.comp_ext_data.branch_prob = Some(weights[1]); - if let Ok(new_ext) = AstElemExt::ternary(ast, a, b, c) { - insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; - } - } - } - } - Ok(()) -} - /// Obtain the best compilation of for p=1.0 and q=0 pub fn best_compilation( policy: &Concrete, From b6946afd9c734081129db746d816877efb2c8c94 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Sun, 14 Jun 2026 23:15:34 +0000 Subject: [PATCH 12/13] compiler: use PositiveF64 for all probabilites Costs remain using f64, since they may be zero (e.g. the satisfactio cost of a timelock). But all probilities in the compiler should be positive: if they are zero, the corresponding path is impossible and we represent that with an option. --- src/policy/compiler.rs | 165 ++++++++++++++++++++------------- src/primitives/positive_f64.rs | 26 ++++++ 2 files changed, 124 insertions(+), 67 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index ff97bc7de..cf3b72166 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -95,20 +95,20 @@ fn best_compilations_or( policy_cache: &mut PolicyCache, policy: &Concrete, subs: &[(NonZeroU32, Arc>)], - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result<(), CompilerError> { - let total = u32::from(subs[0].0) as f64 + u32::from(subs[1].0) as f64; - let lw = u32::from(subs[0].0) as f64 / total; - let rw = u32::from(subs[1].0) as f64 / total; + let total = PositiveF64::from(subs[0].0) + PositiveF64::from(subs[1].0); + let lw = PositiveF64::from(subs[0].0) / total; + let rw = PositiveF64::from(subs[1].0) / total; //and-or let mut insert_ternary = |policy_cache: &mut _, a: &BTreeMap<_, _>, b: &BTreeMap<_, _>, c: &BTreeMap<_, _>, - lw: f64, - rw: f64| + lw: PositiveF64, + rw: PositiveF64| -> Result<(), CompilerError> { for a in a.values() { for b in b.values() { @@ -134,7 +134,7 @@ fn best_compilations_or( policy_cache, x[0].as_ref(), lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + Some((rw * sat_prob).conditional_add(dissat_prob)), )?; let a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; @@ -142,7 +142,7 @@ fn best_compilations_or( policy_cache, x[1].as_ref(), lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + Some((rw * sat_prob).conditional_add(dissat_prob)), )?; let b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; @@ -156,7 +156,7 @@ fn best_compilations_or( policy_cache, x[0].as_ref(), rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + Some((lw * sat_prob).conditional_add(dissat_prob)), )?; let a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; @@ -164,7 +164,7 @@ fn best_compilations_or( policy_cache, x[1].as_ref(), rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + Some((lw * sat_prob).conditional_add(dissat_prob)), )?; let b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; @@ -174,9 +174,9 @@ fn best_compilations_or( insert_ternary(policy_cache, &b1, &a2, &c, rw, lw)?; }; - let dissat_probs = |w: f64| -> Vec> { + let dissat_probs = |w: PositiveF64| -> Vec> { vec![ - Some(dissat_prob.unwrap_or(0 as f64) + w * sat_prob), + Some((w * sat_prob).conditional_add(dissat_prob)), Some(w * sat_prob), dissat_prob, None, @@ -198,8 +198,8 @@ fn best_compilations_or( let mut insert_binary = |left: &BTreeMap<_, _>, right: &BTreeMap<_, _>, - lw: f64, - rw: f64, + lw: PositiveF64, + rw: PositiveF64, combinator: fn(&_, &_, _, _) -> Result<_, _>| -> Result<(), CompilerError> { for l in left.values() { @@ -280,8 +280,8 @@ impl CompilationKey { } /// Helper to create compilation key from components - fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { - Self { ty, expensive_verify, dissat_prob: dissat_prob.map(PositiveF64) } + fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { + Self { ty, expensive_verify, dissat_prob } } } @@ -383,32 +383,34 @@ impl CompilerExtData { Self { sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } - fn or_b(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { + fn or_b(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - sat_cost: lprob * (l.sat_cost + r.dissat_cost.unwrap()) - + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * (l.sat_cost + r.dissat_cost.unwrap()) + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: Some(l.dissat_cost.unwrap() + r.dissat_cost.unwrap()), } } - fn or_d(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { + fn or_d(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * l.sat_cost + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: r.dissat_cost.map(|rd| l.dissat_cost.unwrap() + rd), } } - fn or_c(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { + fn or_c(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * l.sat_cost + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: None, } } #[allow(clippy::manual_map)] // Complex if/let is better as is. - fn or_i(l: Self, r: Self, lprob: f64, rprob: f64) -> Self { + fn or_i(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - sat_cost: lprob * (2.0 + l.sat_cost) + rprob * (1.0 + r.sat_cost), + sat_cost: f64::from(lprob) * (2.0 + l.sat_cost) + f64::from(rprob) * (1.0 + r.sat_cost), dissat_cost: if let (Some(ldis), Some(rdis)) = (l.dissat_cost, r.dissat_cost) { if (2.0 + ldis) > (1.0 + rdis) { Some(1.0 + rdis) @@ -425,12 +427,13 @@ impl CompilerExtData { } } - fn and_or(a: Self, b: Self, c: Self, lprob: f64, rprob: f64) -> Self { + fn and_or(a: Self, b: Self, c: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { let adis = a .dissat_cost .expect("BUG: and_or first arg(a) must be dissatisfiable"); Self { - sat_cost: lprob * (a.sat_cost + b.sat_cost) + rprob * (adis + c.sat_cost), + sat_cost: f64::from(lprob) * (a.sat_cost + b.sat_cost) + + f64::from(rprob) * (adis + c.sat_cost), dissat_cost: c.dissat_cost.map(|cdis| adis + cdis), } } @@ -467,11 +470,11 @@ impl AstElemExt { /// Compute a 1-dimensional cost, given a probability of satisfaction /// and a probability of dissatisfaction; if `dissat_prob` is `None` /// then it is assumed that dissatisfaction never occurs - fn cost_1d(&self, sat_prob: f64, dissat_prob: Option) -> f64 { + fn cost_1d(&self, sat_prob: PositiveF64, dissat_prob: Option) -> f64 { self.ms.ext.pk_cost as f64 - + self.comp_ext_data.sat_cost * sat_prob + + self.comp_ext_data.sat_cost * f64::from(sat_prob) + match (dissat_prob, self.comp_ext_data.dissat_cost) { - (Some(prob), Some(cost)) => prob * cost, + (Some(prob), Some(cost)) => f64::from(prob) * cost, (Some(_), None) => f64::INFINITY, (None, Some(_)) => 0.0, (None, None) => 0.0, @@ -589,8 +592,8 @@ impl AstElemExt { a: &Self, b: &Self, c: &Self, - l_weight: f64, - r_weight: f64, + l_weight: PositiveF64, + r_weight: PositiveF64, ) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::AndOr( @@ -608,7 +611,12 @@ impl AstElemExt { }) } - fn or_b(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + fn or_b( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::OrB( Arc::clone(&left.ms), @@ -623,7 +631,12 @@ impl AstElemExt { }) } - fn or_d(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + fn or_d( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::OrD( Arc::clone(&left.ms), @@ -638,7 +651,12 @@ impl AstElemExt { }) } - fn or_c(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + fn or_c( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::OrC( Arc::clone(&left.ms), @@ -653,7 +671,12 @@ impl AstElemExt { }) } - fn or_i(left: &Self, right: &Self, l_weight: f64, r_weight: f64) -> Result { + fn or_i( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { Ok(Self { ms: Self::compose_typeck_only(Terminal::OrI( Arc::clone(&left.ms), @@ -768,8 +791,8 @@ fn all_casts() -> [Cast; 10] { fn insert_elem( map: &mut BTreeMap>, elem: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> bool { // We check before compiling that non-malleable satisfactions exist, and it appears that // there are no cases when malleable satisfactions beat non-malleable ones (and if there @@ -819,8 +842,8 @@ fn insert_elem( fn insert_elem_closure( map: &mut BTreeMap>, astelem_ext: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) { let mut cast_stack: VecDeque> = VecDeque::new(); if insert_elem(map, astelem_ext.clone(), sat_prob, dissat_prob) { @@ -855,8 +878,8 @@ fn insert_best_wrapped( policy: &Concrete, map: &mut BTreeMap>, data: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result<(), CompilerError> { insert_elem_closure(map, data, sat_prob, dissat_prob); @@ -879,17 +902,15 @@ fn insert_best_wrapped( fn best_compilations( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result>, CompilerError> where Pk: MiniscriptKey, Ctx: ScriptContext, { //Check the cache for hits - let ord_sat_prob = PositiveF64(sat_prob); - let ord_dissat_prob = dissat_prob.map(PositiveF64); - if let Some(ret) = policy_cache.get(&(policy.clone(), ord_sat_prob, ord_dissat_prob)) { + if let Some(ret) = policy_cache.get(&(policy.clone(), sat_prob, dissat_prob)) { return Ok(ret.clone()); } @@ -964,7 +985,7 @@ where Concrete::Thresh(ref thresh) => { let k = thresh.k(); let n = thresh.n(); - let k_over_n = k as f64 / n as f64; + let k_over_n = PositiveF64::k_over_n(thresh); let mut sub_ext_data = Vec::with_capacity(n); @@ -972,10 +993,20 @@ where let mut best_ws = Vec::with_capacity(n); let mut min_value = (0, f64::INFINITY); + + let total_sat_prob = sat_prob * k_over_n; + // This match can be written in terms of nested conditional_adds() but seems less clear that way. + let total_dissat_prob = match (dissat_prob, PositiveF64::one_minus_k_over_n(thresh)) { + (Some(dp), Some(kn)) => Some(dp + kn * sat_prob), + (Some(dp), None) => Some(dp), + (None, Some(kn)) => Some(kn * sat_prob), + (None, None) => None, + }; + for (i, ast) in thresh.iter().enumerate() { - let sp = sat_prob * k_over_n; - //Expressions must be dissatisfiable - let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); + let sp = total_sat_prob; + let dp = total_dissat_prob; + let be = best(types::Base::B, policy_cache, ast.as_ref(), sp, dp)?; let bw = best(types::Base::W, policy_cache, ast.as_ref(), sp, dp)?; @@ -1056,7 +1087,7 @@ where } } for k in ret.keys() { - debug_assert_eq!(k.dissat_prob, ord_dissat_prob); + debug_assert_eq!(k.dissat_prob, dissat_prob); } if ret.is_empty() { // The only reason we are discarding elements out of compiler is because @@ -1067,7 +1098,7 @@ where // before calling this compile function Err(CompilerError::LimitsExceeded) } else { - policy_cache.insert((policy.clone(), ord_sat_prob, ord_dissat_prob), ret.clone()); + policy_cache.insert((policy.clone(), sat_prob, dissat_prob), ret.clone()); Ok(ret) } } @@ -1077,7 +1108,7 @@ pub fn best_compilation( policy: &Concrete, ) -> Result, CompilerError> { let mut policy_cache = PolicyCache::::new(); - let x = &*best_t(&mut policy_cache, policy, 1.0, None)?.ms; + let x = &*best_t(&mut policy_cache, policy, PositiveF64::ONE, None)?.ms; if !x.ty.mall.signed { Err(CompilerError::TopLevelSigless) } else if !x.ty.mall.non_malleable { @@ -1091,8 +1122,8 @@ pub fn best_compilation( fn best_t( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1100,9 +1131,7 @@ where { best_compilations(policy_cache, policy, sat_prob, dissat_prob)? .into_iter() - .filter(|&(key, _)| { - key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob.map(PositiveF64) - }) + .filter(|&(key, _)| key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob) .map(|(_, val)| val) .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) @@ -1113,8 +1142,8 @@ fn best( basic_type: types::Base, policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1126,7 +1155,7 @@ where key.ty.corr.base == basic_type && key.ty.corr.unit && val.ms.ty.mall.dissat == types::Dissat::Unique - && key.dissat_prob == dissat_prob.map(PositiveF64) + && key.dissat_prob == dissat_prob }) .map(|(_, val)| val) .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) @@ -1235,18 +1264,20 @@ mod tests { #[test] fn compile_q() { let policy = SPolicy::from_str("or(1@and(pk(A),pk(B)),127@pk(C))").expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::ONE, None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 87.0 + 67.0390625); + assert_eq!(compilation.cost_1d(PositiveF64::ONE, None), 87.0 + 67.0390625); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); // compile into taproot context to avoid limit errors let policy = SPolicy::from_str( "and(and(and(or(127@thresh(2,pk(A),pk(B),thresh(2,or(127@pk(A),1@pk(B)),after(100),or(and(pk(C),after(200)),and(pk(D),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925))),pk(E))),1@pk(F)),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925)),or(127@pk(G),1@after(300))),or(127@after(400),pk(H)))" ).expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::ONE, None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 433.0 + 275.7909749348958); + assert_eq!(compilation.cost_1d(PositiveF64::ONE, None), 433.0 + 275.7909749348958); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); } diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs index 69013cf96..e62b6bd3c 100644 --- a/src/primitives/positive_f64.rs +++ b/src/primitives/positive_f64.rs @@ -19,6 +19,13 @@ impl PositiveF64 { #[cfg(test)] pub const ONE_QUARTER: Self = Self(0.25); + /// Given an [`Option`], if it is `Some` then add it to the value. + /// Otherwise return the unmodified value. + /// + /// Returns the sum (or original value). Does not modify in-place. + #[must_use] + pub fn conditional_add(self, other: Option) -> Self { other.map_or(self, |i| i + self) } + /// Takes an iterator over [`PositiveF64`] and produces a new iterator where /// each item is divided so that they all total to 1. /// @@ -41,6 +48,25 @@ impl PositiveF64 { pub fn n(t: &Threshold) -> Self { Self(t.n() as f64) // cast okay, worst case wil lose precision } + + /// The ratio `k`/`n` of a threshold, as a [`PositiveF64`]. Guaranteed to be + /// in the half-open range `(0, 1]`. + pub fn k_over_n(t: &Threshold) -> Self { + Self(t.k() as f64 / t.n() as f64) // casts okay, worst case wil lose precision + } + + /// One minus the ratio `k` / `n` of a threshold, as a [`PositiveF64`]. Guaranteed + /// to be in the half-open range `[0, 1)`. + /// + /// Returns `None` if the return value would be 0, which is impermissible for the + /// [`PositiveF64`] type. + pub fn one_minus_k_over_n(t: &Threshold) -> Option { + if t.is_and() { + None + } else { + Some(Self(1.0 - t.k() as f64 / t.n() as f64)) // casts okay, worst case wil lose precision + } + } } impl Eq for PositiveF64 {} From 347c334e4891dab1192f6adc79ddb822b2944cfa Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Mon, 15 Jun 2026 12:45:35 +0000 Subject: [PATCH 13/13] Remove the default `PositiveF64` constructor It is now impossible to create an invalid PositiveF64. This eliminates any possibility of division by 0 in the compiler. --- src/policy/compiler.rs | 4 ++-- src/policy/concrete.rs | 6 +++--- src/primitives/positive_f64.rs | 11 ++++++++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index cf3b72166..ca9338464 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -1133,7 +1133,7 @@ where .into_iter() .filter(|&(key, _)| key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob) .map(|(_, val)| val) - .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64::new(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } @@ -1158,7 +1158,7 @@ where && key.dissat_prob == dissat_prob }) .map(|(_, val)| val) - .min_by_key(|ext| PositiveF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64::new(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 64f5f5ce8..dffab644f 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -618,7 +618,7 @@ impl Policy { // Stopping condition: When NONE of the inputs can be further enumerated. 'outer: loop { //--- FIND a plausible node --- - let mut prob: Reverse = Reverse(PositiveF64(0.0)); + let mut prob: Reverse = Reverse(PositiveF64::EPSILON); let mut curr_policy: Arc = Arc::new(Self::Unsatisfiable); let mut curr_pol_replace_vec: Vec<(PositiveF64, Arc)> = vec![]; let mut no_more_enum = false; @@ -1155,9 +1155,9 @@ fn with_huffman_tree( let (p1, s1) = node_weights.pop().expect("len must at least be two"); let (p2, s2) = node_weights.pop().expect("len must at least be two"); - let p = (p1.0).0 + (p2.0).0; + let p = p1.0 + p2.0; node_weights.push(( - Reverse(PositiveF64(p)), + Reverse(p), TapTree::combine(s1, s2) .expect("huffman tree cannot produce depth > 128 given sane weights"), )); diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs index e62b6bd3c..0638f902d 100644 --- a/src/primitives/positive_f64.rs +++ b/src/primitives/positive_f64.rs @@ -9,9 +9,12 @@ use crate::Threshold; /// Ordered f64 for comparison. #[derive(Copy, Clone, PartialEq, Debug)] -pub struct PositiveF64(pub f64); +pub struct PositiveF64(f64); impl PositiveF64 { + /// The smallest representable value of a [`PositiveF64`]. + pub const EPSILON: Self = Self(f64::EPSILON); + /// The constant one. pub const ONE: Self = Self(1.0); @@ -19,6 +22,12 @@ impl PositiveF64 { #[cfg(test)] pub const ONE_QUARTER: Self = Self(0.25); + /// Attempts to create a [`PositiveF64`] from an ordinary `f64`. + pub fn new(f: f64) -> Option { + // Can likely make this function const in Rust 1.83 + (f > 0.0).then_some(Self(f)) + } + /// Given an [`Option`], if it is `Some` then add it to the value. /// Otherwise return the unmodified value. ///