Skip to content

Commit a927af6

Browse files
committed
refactor(policy): rewrite semantic policy FromStr as direct single-pass parser
Replace the recursive `rewrite_math_notation` that rescanned the input at every depth with a single linear pass that builds `Policy<Pk>` directly using an explicit frame stack. Matches `expression::Tree`'s non-recursive conventions. Addresses apoelstra's review on rust-bitcoin#914.
1 parent eb6011e commit a927af6

1 file changed

Lines changed: 190 additions & 119 deletions

File tree

src/policy/semantic.rs

Lines changed: 190 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -358,140 +358,178 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
358358

359359
impl<Pk: FromStrKey> str::FromStr for Policy<Pk> {
360360
type Err = Error;
361-
fn from_str(s: &str) -> Result<Policy<Pk>, Error> {
362-
let normalized;
363-
let input = if s.contains('∧') || s.contains('∨') || s.contains("#{") {
364-
normalized = rewrite_math_notation(s).map_err(Error::Parse)?;
365-
normalized.as_str()
366-
} else {
367-
s
368-
};
369-
let tree = expression::Tree::from_str(input)?;
370-
expression::FromTree::from_tree(tree.root())
371-
}
372-
}
373-
374-
/// Rewrites Display's mathematical form into the function-call form consumed by `expression::Tree`.
375-
fn rewrite_math_notation(s: &str) -> Result<String, crate::ParseError> {
376-
let mut out = String::with_capacity(s.len());
377-
rewrite_expr(s.trim(), &mut out)?;
378-
Ok(out)
361+
fn from_str(s: &str) -> Result<Policy<Pk>, Error> { parse_policy(s) }
379362
}
380363

381-
fn rewrite_expr(s: &str, out: &mut String) -> Result<(), crate::ParseError> {
382-
let s = s.trim();
383-
if let Some(body) = s.strip_prefix("#{") {
384-
let close = find_matching_close(body, b'}')?;
385-
let inside = &body[..close];
386-
let k_str = body[close + 1..]
387-
.trim_start()
388-
.strip_prefix('=')
389-
.ok_or_else(malformed_math)?
390-
.trim();
391-
if k_str.is_empty() || !k_str.bytes().all(|b| b.is_ascii_digit()) {
392-
return Err(malformed_math());
393-
}
394-
out.push_str("thresh(");
395-
out.push_str(k_str);
396-
for child in split_top_level(inside, ", ") {
397-
out.push(',');
398-
rewrite_expr(child, out)?;
399-
}
400-
out.push(')');
401-
Ok(())
402-
} else if let Some(body) = s.strip_prefix('(') {
403-
let close = find_matching_close(body, b')')?;
404-
let inside = &body[..close];
405-
if !body[close + 1..].is_empty() {
406-
return Err(malformed_math());
407-
}
408-
let (op, sep) = if has_top_level(inside, " ∧ ") {
409-
("and", " ∧ ")
410-
} else if has_top_level(inside, " ∨ ") {
411-
("or", " ∨ ")
412-
} else {
413-
return rewrite_expr(inside, out);
414-
};
415-
out.push_str(op);
416-
out.push('(');
417-
for (i, child) in split_top_level(inside, sep).into_iter().enumerate() {
418-
if i > 0 {
419-
out.push(',');
420-
}
421-
rewrite_expr(child, out)?;
422-
}
423-
out.push(')');
424-
Ok(())
425-
} else {
426-
out.push_str(s);
427-
Ok(())
364+
/// Parses a semantic policy from either the mathematical Display form
365+
/// (`(a ∧ b)`, `(a ∨ b)`, `#{a, b, ..} = k`) or the legacy function-call
366+
/// form (`and(..)`, `or(..)`, `thresh(k, ..)`).
367+
///
368+
/// This is a single linear pass: each byte of input is visited at most
369+
/// twice (once by the outer scan and, for bytes that belong to a terminal
370+
/// or to a nested function-call atom, once by [`expression::Tree::from_str`]
371+
/// which is itself non-recursive). The parser uses explicit frame and
372+
/// operand stacks rather than recursion.
373+
fn parse_policy<Pk: FromStrKey>(s: &str) -> Result<Policy<Pk>, Error> {
374+
// UTF-8 byte sequences for the mathematical operators.
375+
const AND_SEP: &[u8] = b" \xE2\x88\xA7 ";
376+
const OR_SEP: &[u8] = b" \xE2\x88\xA8 ";
377+
378+
#[derive(Copy, Clone, PartialEq, Eq)]
379+
enum Op {
380+
And,
381+
Or,
382+
}
383+
struct Frame<Pk: MiniscriptKey> {
384+
subs: Vec<Arc<Policy<Pk>>>,
385+
op: Option<Op>,
386+
is_thresh: bool,
428387
}
429-
}
430388

431-
fn find_matching_close(s: &str, expected: u8) -> Result<usize, crate::ParseError> {
432389
let bytes = s.as_bytes();
433-
let mut depth: i32 = 1;
434-
for (i, &ch) in bytes.iter().enumerate() {
435-
match ch {
436-
b'(' | b'{' => depth += 1,
437-
b')' | b'}' => {
438-
depth -= 1;
439-
if depth == 0 {
440-
return if ch == expected {
441-
Ok(i)
442-
} else {
443-
Err(malformed_math())
444-
};
445-
}
390+
let n = bytes.len();
391+
let mut frames: Vec<Frame<Pk>> = Vec::new();
392+
let mut cur: Option<Arc<Policy<Pk>>> = None;
393+
let mut i = 0;
394+
395+
while i < n {
396+
let b = bytes[i];
397+
if cur.is_none() {
398+
// Expecting a fresh operand: a math group `(`, a threshold `#{`,
399+
// or an atom like `pk(..)`, `UNSATISFIABLE`, etc.
400+
if b == b'(' {
401+
frames.push(Frame { subs: Vec::new(), op: None, is_thresh: false });
402+
i += 1;
403+
} else if b == b'#' && bytes.get(i + 1) == Some(&b'{') {
404+
frames.push(Frame { subs: Vec::new(), op: None, is_thresh: true });
405+
i += 2;
406+
} else {
407+
let end = scan_atom(bytes, i).ok_or_else(|| Error::Parse(malformed_math()))?;
408+
let atom = &s[i..end];
409+
let tree = expression::Tree::from_str(atom)?;
410+
let policy: Policy<Pk> = expression::FromTree::from_tree(tree.root())?;
411+
cur = Some(Arc::new(policy));
412+
i = end;
413+
}
414+
} else if bytes[i..].starts_with(AND_SEP) || bytes[i..].starts_with(OR_SEP) {
415+
let new_op = if bytes[i..].starts_with(AND_SEP) {
416+
Op::And
417+
} else {
418+
Op::Or
419+
};
420+
let frame = frames
421+
.last_mut()
422+
.ok_or_else(|| Error::Parse(malformed_math()))?;
423+
if frame.is_thresh {
424+
return Err(Error::Parse(malformed_math()));
425+
}
426+
match frame.op {
427+
None => frame.op = Some(new_op),
428+
Some(existing) if existing == new_op => {}
429+
Some(_) => return Err(Error::Parse(malformed_math())),
430+
}
431+
frame.subs.push(cur.take().unwrap());
432+
i += AND_SEP.len();
433+
} else if b == b',' && bytes.get(i + 1) == Some(&b' ') {
434+
let frame = frames
435+
.last_mut()
436+
.ok_or_else(|| Error::Parse(malformed_math()))?;
437+
if !frame.is_thresh {
438+
return Err(Error::Parse(malformed_math()));
439+
}
440+
frame.subs.push(cur.take().unwrap());
441+
i += 2;
442+
} else if b == b')' {
443+
let mut frame = frames.pop().ok_or_else(|| Error::Parse(malformed_math()))?;
444+
if frame.is_thresh {
445+
return Err(Error::Parse(malformed_math()));
446+
}
447+
frame.subs.push(cur.take().unwrap());
448+
if frame.subs.len() < 2 {
449+
return Err(Error::Parse(malformed_math()));
450+
}
451+
let op = frame.op.ok_or_else(|| Error::Parse(malformed_math()))?;
452+
let k = match op {
453+
Op::And => frame.subs.len(),
454+
Op::Or => 1,
455+
};
456+
let thresh = Threshold::new(k, frame.subs).map_err(Error::Threshold)?;
457+
cur = Some(Arc::new(Policy::Thresh(thresh)));
458+
i += 1;
459+
} else if b == b'}' {
460+
let mut frame = frames.pop().ok_or_else(|| Error::Parse(malformed_math()))?;
461+
if !frame.is_thresh {
462+
return Err(Error::Parse(malformed_math()));
446463
}
447-
_ => {}
464+
frame.subs.push(cur.take().unwrap());
465+
i += 1;
466+
if bytes.get(i..i + 3) != Some(b" = ") {
467+
return Err(Error::Parse(malformed_math()));
468+
}
469+
i += 3;
470+
let k_start = i;
471+
while i < n && bytes[i].is_ascii_digit() {
472+
i += 1;
473+
}
474+
if i == k_start {
475+
return Err(Error::Parse(malformed_math()));
476+
}
477+
let k = expression::parse_num(&s[k_start..i])
478+
.map_err(|_| Error::Parse(malformed_math()))? as usize;
479+
let thresh = Threshold::new(k, frame.subs).map_err(Error::Threshold)?;
480+
// In semantic policies we reserve `#{..} = k` for k strictly between
481+
// 1 and n; k=1 must be spelled `∨` and k=n must be spelled `∧`.
482+
if thresh.is_or() {
483+
return Err(Error::ParseThreshold(crate::ParseThresholdError::IllegalOr));
484+
}
485+
if thresh.is_and() {
486+
return Err(Error::ParseThreshold(crate::ParseThresholdError::IllegalAnd));
487+
}
488+
cur = Some(Arc::new(Policy::Thresh(thresh)));
489+
} else {
490+
return Err(Error::Parse(malformed_math()));
448491
}
449492
}
450-
Err(malformed_math())
451-
}
452493

453-
fn has_top_level(s: &str, sep: &str) -> bool {
454-
let bytes = s.as_bytes();
455-
let sep_bytes = sep.as_bytes();
456-
let mut depth: i32 = 0;
457-
let mut i = 0;
458-
while i < bytes.len() {
459-
if depth == 0 && bytes[i..].starts_with(sep_bytes) {
460-
return true;
461-
}
462-
match bytes[i] {
463-
b'(' | b'{' => depth += 1,
464-
b')' | b'}' => depth -= 1,
465-
_ => {}
466-
}
467-
i += 1;
494+
if !frames.is_empty() {
495+
return Err(Error::Parse(malformed_math()));
468496
}
469-
false
497+
let root = cur.ok_or_else(|| Error::Parse(malformed_math()))?;
498+
// `cur` is the sole outstanding reference to `root`, so `try_unwrap` always succeeds.
499+
Ok(Arc::try_unwrap(root).unwrap_or_else(|arc| (*arc).clone()))
470500
}
471501

472-
fn split_top_level<'a>(s: &'a str, sep: &str) -> Vec<&'a str> {
473-
let bytes = s.as_bytes();
474-
let sep_bytes = sep.as_bytes();
475-
let mut depth: i32 = 0;
476-
let mut start = 0;
477-
let mut i = 0;
478-
let mut out = Vec::new();
479-
while i < bytes.len() {
480-
if depth == 0 && bytes[i..].starts_with(sep_bytes) {
481-
out.push(&s[start..i]);
482-
i += sep_bytes.len();
483-
start = i;
484-
continue;
502+
/// Finds the end of a single atom starting at `start`.
503+
///
504+
/// An atom is a name (any run of non-delimiter bytes) optionally followed
505+
/// by a balanced `(...)` argument list. Returns `None` if the atom is
506+
/// empty or its parentheses are unbalanced.
507+
fn scan_atom(bytes: &[u8], start: usize) -> Option<usize> {
508+
fn is_delim(b: u8) -> bool { matches!(b, b' ' | b',' | b'(' | b')' | b'{' | b'}' | b'#') }
509+
510+
let mut i = start;
511+
while i < bytes.len() && !is_delim(bytes[i]) {
512+
i += 1;
513+
}
514+
if i == start {
515+
return None;
516+
}
517+
if i < bytes.len() && bytes[i] == b'(' {
518+
let mut depth: u32 = 1;
519+
i += 1;
520+
while i < bytes.len() && depth > 0 {
521+
match bytes[i] {
522+
b'(' => depth += 1,
523+
b')' => depth -= 1,
524+
_ => {}
525+
}
526+
i += 1;
485527
}
486-
match bytes[i] {
487-
b'(' | b'{' => depth += 1,
488-
b')' | b'}' => depth -= 1,
489-
_ => {}
528+
if depth != 0 {
529+
return None;
490530
}
491-
i += 1;
492531
}
493-
out.push(&s[start..]);
494-
out
532+
Some(i)
495533
}
496534

497535
fn malformed_math() -> crate::ParseError {
@@ -931,6 +969,39 @@ mod tests {
931969
.is_ok());
932970
}
933971

972+
#[test]
973+
fn parse_math_notation() {
974+
// Nested groups round-trip through the direct parser.
975+
let a = StringPolicy::from_str("((pk(A) ∧ pk(B)) ∨ pk(C))").unwrap();
976+
let b = StringPolicy::from_str("or(and(pk(A),pk(B)),pk(C))").unwrap();
977+
assert_eq!(a, b);
978+
979+
// Thresh with explicit k strictly between 1 and n.
980+
let a = StringPolicy::from_str("#{pk(A), pk(B), pk(C), pk(D)} = 3").unwrap();
981+
let b = StringPolicy::from_str("thresh(3,pk(A),pk(B),pk(C),pk(D))").unwrap();
982+
assert_eq!(a, b);
983+
984+
// Thresh nested inside a math group.
985+
let a = StringPolicy::from_str("(pk(A) ∧ #{pk(B), pk(C), pk(D)} = 2)").unwrap();
986+
let b = StringPolicy::from_str("and(pk(A),thresh(2,pk(B),pk(C),pk(D)))").unwrap();
987+
assert_eq!(a, b);
988+
989+
// Top-level terminals still parse.
990+
assert_eq!(StringPolicy::from_str("UNSATISFIABLE").unwrap(), Policy::Unsatisfiable);
991+
assert_eq!(StringPolicy::from_str("TRIVIAL").unwrap(), Policy::Trivial);
992+
993+
// Mixing ∧ and ∨ in the same group is not valid (Display never emits it).
994+
assert!(StringPolicy::from_str("(pk(A) ∧ pk(B) ∨ pk(C))").is_err());
995+
// Unbalanced / unterminated math groups.
996+
assert!(StringPolicy::from_str("(pk(A) ∧ pk(B)").is_err());
997+
assert!(StringPolicy::from_str("#{pk(A), pk(B)} = ").is_err());
998+
// k=1 and k=n spelled as thresh must be rejected in semantic policies.
999+
assert!(StringPolicy::from_str("#{pk(A), pk(B)} = 1").is_err());
1000+
assert!(StringPolicy::from_str("#{pk(A), pk(B)} = 2").is_err());
1001+
// Trailing garbage after a valid expression.
1002+
assert!(StringPolicy::from_str("pk(A)xyz").is_err());
1003+
}
1004+
9341005
#[test]
9351006
fn semantic_analysis() {
9361007
let policy = StringPolicy::from_str("pk()").unwrap();

0 commit comments

Comments
 (0)