From 11dbb430be01131e88b8238efac4f9adf0ac769d Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 24 Dec 2024 18:37:03 +0200 Subject: [PATCH 01/16] implement some functions and if statment:Wq --- libraries/math-parser/src/ast.rs | 4 + libraries/math-parser/src/constants.rs | 223 ++++++++++++++++++++++++- libraries/math-parser/src/executer.rs | 11 +- libraries/math-parser/src/grammer.pest | 12 +- libraries/math-parser/src/lib.rs | 44 +++++ libraries/math-parser/src/parser.rs | 57 +++++-- libraries/math-parser/src/value.rs | 46 ++++- 7 files changed, 370 insertions(+), 27 deletions(-) diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index 4c42fc4b1d..a107bc492b 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -56,6 +56,9 @@ pub enum BinaryOp { Mul, Div, Pow, + Leq, + Geq, + Eq, } #[derive(Debug, PartialEq, Clone, Copy)] @@ -72,4 +75,5 @@ pub enum Node { FnCall { name: String, expr: Vec }, BinOp { lhs: Box, op: BinaryOp, rhs: Box }, UnaryOp { expr: Box, op: UnaryOp }, + Conditional { condition: Box, if_block: Box, else_block: Box }, } diff --git a/libraries/math-parser/src/constants.rs b/libraries/math-parser/src/constants.rs index c010d13253..f439be19d8 100644 --- a/libraries/math-parser/src/constants.rs +++ b/libraries/math-parser/src/constants.rs @@ -2,7 +2,7 @@ use crate::value::{Number, Value}; use lazy_static::lazy_static; use num_complex::{Complex, ComplexFloat}; use std::collections::HashMap; -use std::f64::consts::PI; +use std::f64::consts::{LN_2, PI}; type FunctionImplementation = Box Option + Send + Sync>; lazy_static! { @@ -116,6 +116,227 @@ lazy_static! { _ => None, }), ); + // Hyperbolic Functions + map.insert( + "sinh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sinh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sinh()))), + _ => None, + }), + ); + + map.insert( + "cosh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cosh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cosh()))), + _ => None, + }), + ); + + map.insert( + "tanh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tanh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tanh()))), + _ => None, + }), + ); + + // Inverse Hyperbolic Functions + map.insert( + "asinh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.asinh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.asinh()))), + _ => None, + }), + ); + + map.insert( + "acosh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.acosh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.acosh()))), + _ => None, + }), + ); + + map.insert( + "atanh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.atanh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.atanh()))), + _ => None, + }), + ); + + // Logarithm Functions + map.insert( + "ln", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.ln()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.ln()))), + _ => None, + }), + ); + + map.insert( + "log", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.log10()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.log10()))), + [Value::Number(n), Value::Number(base)] => { + // Custom base logarithm using change of base formula + let compute_log = |x: f64, b: f64| -> f64 { x.ln() / b.ln() }; + match (n, base) { + (Number::Real(x), Number::Real(b)) => Some(Value::Number(Number::Real(compute_log(*x, *b)))), + _ => None, + } + } + _ => None, + }), + ); + + map.insert( + "log2", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.log2()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex / LN_2))), + _ => None, + }), + ); + + // Root Functions + map.insert( + "sqrt", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sqrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sqrt()))), + _ => None, + }), + ); + + map.insert( + "cbrt", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cbrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.powf(1.0/3.0)))), + _ => None, + }), + ); + + // Geometry Functions + map.insert( + "hypot", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + Some(Value::Number(Number::Real(a.hypot(*b)))) + }, + _ => None, + }), + ); + + // Mapping Functions + map.insert( + "abs", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.abs()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.abs()))), + _ => None, + }), + ); + + map.insert( + "floor", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.floor()))), + _ => None, + }), + ); + + map.insert( + "ceil", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.ceil()))), + _ => None, + }), + ); + + map.insert( + "round", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.round()))), + _ => None, + }), + ); + + map.insert( + "clamp", + Box::new(|values| match values { + [Value::Number(Number::Real(x)), Value::Number(Number::Real(min)), Value::Number(Number::Real(max))] => { + Some(Value::Number(Number::Real(x.clamp(*min, *max)))) + }, + _ => None, + }), + ); + + map.insert( + "lerp", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b)), Value::Number(Number::Real(t))] => { + Some(Value::Number(Number::Real(a + (b - a) * t))) + }, + _ => None, + }), + ); + + // Complex Number Functions + map.insert( + "real", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.re))), + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(*real))), + _ => None, + }), + ); + + map.insert( + "imag", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.im))), + [Value::Number(Number::Real(_))] => Some(Value::Number(Number::Real(0.0))), + _ => None, + }), + ); + + // Logical Functions + map.insert( + "isnan", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(if real.is_nan() { 1.0 } else { 0.0 }))), + _ => None, + }), + ); + + map.insert( + "eq", + Box::new(|values| match values { + [Value::Number(a), Value::Number(b)] => Some(Value::Number(Number::Real(if a == b { 1.0 } else { 0.0 }))), + _ => None, + }), + ); + + map.insert( + "greater", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + Some(Value::Number(Number::Real(if a > b { 1.0 } else { 0.0 }))) + }, + _ => None, + }), + ); map }; diff --git a/libraries/math-parser/src/executer.rs b/libraries/math-parser/src/executer.rs index 9d6180f1ab..85871d769c 100644 --- a/libraries/math-parser/src/executer.rs +++ b/libraries/math-parser/src/executer.rs @@ -2,6 +2,7 @@ use crate::ast::{Literal, Node}; use crate::constants::DEFAULT_FUNCTIONS; use crate::context::{EvalContext, FunctionProvider, ValueProvider}; use crate::value::{Number, Value}; +use num_complex::Complex; use thiserror::Error; #[derive(Debug, Error)] @@ -24,7 +25,7 @@ impl Node { }, Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) { - (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))), + (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs).ok_or(EvalError::TypeError)?)), }, Node::UnaryOp { expr, op } => match expr.eval(context)? { Value::Number(num) => Ok(Value::Number(num.unary_op(*op))), @@ -40,6 +41,14 @@ impl Node { context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string())) } } + Node::Conditional { condition, if_block, else_block } => { + let condition = match condition.eval(context)? { + Value::Number(Number::Real(number)) => number != 0.0, + Value::Number(Number::Complex(number)) => number != Complex::ZERO, + }; + + if condition { if_block.eval(context) } else { else_block.eval(context) } + } } } } diff --git a/libraries/math-parser/src/grammer.pest b/libraries/math-parser/src/grammer.pest index d7a61939df..d0aa643aea 100644 --- a/libraries/math-parser/src/grammer.pest +++ b/libraries/math-parser/src/grammer.pest @@ -4,24 +4,26 @@ WHITESPACE = _{ " " | "\t" } program = _{ SOI ~ expr ~ EOI } expr = { atom ~ (infix ~ atom)* } -atom = _{ prefix? ~ primary ~ postfix? } -infix = _{ add | sub | mul | div | pow | paren } +atom = _{ prefix? ~ primary } +infix = _{ add | sub | mul | div | pow | leq | geq | eq | paren } add = { "+" } // Addition sub = { "-" } // Subtraction mul = { "*" } // Multiplication div = { "/" } // Division mod = { "%" } // Modulo pow = { "^" } // Exponentiation +leq = { "<" } +geq = { ">" } +eq = { "=="} paren = { "" } // Implicit multiplication operator prefix = _{ neg | sqrt } neg = { "-" } // Negation sqrt = { "sqrt" } -postfix = _{ fac } -fac = { "!" } // Factorial -primary = _{ ("(" ~ expr ~ ")") | lit | constant | fn_call | ident } +primary = _{ ("(" ~ expr ~ ")") | conditional | lit | constant | fn_call | ident } +conditional = { "if" ~ "(" ~ expr ~ ")" ~ "{" ~ expr ~ "}" ~ "else" ~ "{" ~ expr ~ "}"} fn_call = { ident ~ "(" ~ expr ~ ("," ~ expr)* ~ ")" } ident = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } lit = { unit | ((float | int) ~ unit?) } diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index e596d78dff..d3e7d04df9 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -145,5 +145,49 @@ mod tests { trig_tan_pi_div_four: "tan(pi/4)" => (1.0, Unit::BASE_UNIT), trig_sin_tau: "sin(tau)" => (0.0, Unit::BASE_UNIT), trig_cos_tau_div_two: "cos(tau/2)" => (-1.0, Unit::BASE_UNIT), + + // Basic if statements + if_true_condition: "if(1){5} else {3}" => (5., Unit::BASE_UNIT), + if_false_condition: "if(0){5} else {3}" => (3., Unit::BASE_UNIT), + + // Arithmetic conditions + if_arithmetic_true: "if(2+2-4){1} else {0}" => (0., Unit::BASE_UNIT), + if_arithmetic_false: "if(3*2-5){1} else {0}" => (1., Unit::BASE_UNIT), + + // Nested arithmetic + if_complex_arithmetic: "if((5+3)*(2-1)){10} else {20}" => (10., Unit::BASE_UNIT), + if_with_division: "if(8/4-2){15} else {25}" => (15., Unit::BASE_UNIT), + + // Constants in conditions + if_with_pi: "if(pi > 3){1} else {0}" => (1., Unit::BASE_UNIT), + if_with_e: "if(e < 3){1} else {0}" => (0., Unit::BASE_UNIT), + + // Functions in conditions + if_with_sqrt: "if(sqrt(16) == 4){1} else {0}" => (1., Unit::BASE_UNIT), + if_with_sin: "if(sin(pi)){1} else {0}" => (0., Unit::BASE_UNIT), + + // Nested if statements + nested_if: "if(1){if(0){1} else {2}} else {3}" => (2., Unit::BASE_UNIT), + nested_if_complex: "if(2-2){if(1){5} else {6}} else {if(1){7} else {8}}" => (7., Unit::BASE_UNIT), + + // If statements with variables + if_with_var: "if(x > 0){1} else {0}" => (1., Unit::BASE_UNIT), + if_var_arithmetic: "if(x + 3 > 5){1} else {0}" => (0., Unit::BASE_UNIT), + + // Mixed operations in conditions and blocks + if_complex_condition: "if(sqrt(16) + sin(pi) < 5){2*pi} else {3*e}" => (2. * std::f64::consts::PI, Unit::BASE_UNIT), + if_complex_blocks: "if(1){2*sqrt(16) + sin(pi/2)} else {3*cos(0) + 4}" => (8., Unit::BASE_UNIT), + + // Edge cases + if_zero: "if(0.0){1} else {2}" => (2., Unit::BASE_UNIT), + if_negative: "if(-1){1} else {2}" => (1., Unit::BASE_UNIT), + if_infinity: "if(inf){1} else {2}" => (1., Unit::BASE_UNIT), + + // Units in if statements + if_with_units: "if(5m > 3m){10s} else {20s}" => (10., Unit::TIME), + if_mixed_units: "if(2km/h > 1m/s){5kg} else {10kg}" => (10., Unit::MASS), + + // Complex nested expressions + if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1)){3 + 4 * 2} else {5 - 2 / 1}" => (11., Unit::BASE_UNIT), } } diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 101995d87b..f827914d3a 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -11,7 +11,7 @@ use std::num::{ParseFloatError, ParseIntError}; use thiserror::Error; #[derive(Parser)] -#[grammar = "./grammer.pest"] // Point to the grammar file +#[grammar = "./grammer.pest"] struct ExprParser; lazy_static! { @@ -20,7 +20,8 @@ lazy_static! { .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left)) .op(Op::infix(Rule::pow, Assoc::Right)) - .op(Op::postfix(Rule::fac) | Op::postfix(Rule::EOI)) + .op(Op::infix(Rule::geq, Assoc::Left) | Op::infix(Rule::leq, Assoc::Left) | Op::infix(Rule::eq, Assoc::Left)) + .op(Op::postfix(Rule::EOI)) .op(Op::prefix(Rule::sqrt)) .op(Op::prefix(Rule::neg)) }; @@ -185,6 +186,22 @@ fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { let value = primary.as_str().parse::()?; (Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT)) } + Rule::conditional => { + let mut inner = primary.into_inner(); + + let condition = parse_expr(inner.next().expect("expected conditional to have condition").into_inner())?; + let if_block = parse_expr(inner.next().expect("expected conditional to have if block").into_inner())?; + let else_block = parse_expr(inner.next().expect("expected conditional to have else block").into_inner())?; + + ( + Node::Conditional { + condition: Box::new(condition.0), + if_block: Box::new(if_block.0), + else_block: Box::new(else_block.0), + }, + if_block.1, + ) + } rule => unreachable!("unexpected rule: {:?}", rule), }) }) @@ -217,19 +234,10 @@ fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { Ok((node, NodeMetadata::new(unit))) }) .map_postfix(|lhs, op| { - let (lhs_node, lhs_metadata) = lhs?; - let op = match op.as_rule() { - Rule::EOI => return Ok((lhs_node, lhs_metadata)), - Rule::fac => UnaryOp::Fac, + Rule::EOI => return lhs, rule => unreachable!("unexpected rule: {:?}", rule), }; - - if !lhs_metadata.unit.is_base() { - return Err(ParseError::Type(TypeError::InvalidUnaryOp(lhs_metadata.unit, op))); - } - - Ok((Node::UnaryOp { expr: Box::new(lhs_node), op }, lhs_metadata)) }) .map_infix(|lhs, op, rhs| { let (lhs, lhs_metadata) = lhs?; @@ -242,6 +250,9 @@ fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { Rule::div => BinaryOp::Div, Rule::pow => BinaryOp::Pow, Rule::paren => BinaryOp::Mul, + Rule::leq => BinaryOp::Leq, + Rule::geq => BinaryOp::Geq, + Rule::eq => BinaryOp::Eq, rule => unreachable!("unexpected rule: {:?}", rule), }; @@ -269,18 +280,25 @@ fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { BinaryOp::Pow => { return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); } + BinaryOp::Eq | BinaryOp::Leq | BinaryOp::Geq => { + if lhs_unit == rhs_unit { + Unit::base_unit() + } else { + return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); + } + } }, (true, false) => match op { BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))), BinaryOp::Pow => { //TODO: improve error type - //TODO: support 1 / int + //TODO: support fractions if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) { if (val - val as i32 as f64).abs() <= f64::EPSILON { Unit { - length: lhs_unit.length * val as i32, - mass: lhs_unit.mass * val as i32, + length: lhs_unit.length * val.round() as i32, + mass: lhs_unit.mass * val.round() as i32, time: lhs_unit.time * val as i32, } } else { @@ -374,6 +392,15 @@ mod tests { op: BinaryOp::Pow, rhs: Box::new(Node::Lit(Literal::Float(2.0))), }), + }, + test_conditional_expr: "if (x+3) {0} else {1}" => Node::Conditional{ + condition: Box::new(Node::BinOp{ + lhs: Box::new(Node::Var("x".to_string())), + op: BinaryOp::Add, + rhs: Box::new(Node::Lit(Literal::Float(3.0))), + }), + if_block: Box::new(Node::Lit(Literal::Float(0.0))), + else_block: Box::new(Node::Lit(Literal::Float(1.0))), } } } diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index 3577f3ea60..387959d9f8 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -52,7 +52,7 @@ impl std::fmt::Display for Number { } impl Number { - pub fn binary_op(self, op: BinaryOp, other: Number) -> Number { + pub fn binary_op(self, op: BinaryOp, other: Number) -> Option { match (self, other) { (Number::Real(lhs), Number::Real(rhs)) => { let result = match op { @@ -61,8 +61,29 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powf(rhs), + BinaryOp::Leq => { + if lhs > rhs { + 1.0 + } else { + 0.0 + } + } + BinaryOp::Geq => { + if lhs < rhs { + 1.0 + } else { + 0.0 + } + } + BinaryOp::Eq => { + if lhs == rhs { + 1.0 + } else { + 0.0 + } + } }; - Number::Real(result) + Some(Number::Real(result)) } (Number::Complex(lhs), Number::Complex(rhs)) => { @@ -72,8 +93,21 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powc(rhs), + BinaryOp::Leq => { + return None; + } + BinaryOp::Geq => { + return None; + } + BinaryOp::Eq => { + if lhs == rhs { + return Some(Number::Real(1.0)); + } else { + return Some(Number::Real(0.0)); + } + } }; - Number::Complex(result) + Some(Number::Complex(result)) } (Number::Real(lhs), Number::Complex(rhs)) => { @@ -84,8 +118,9 @@ impl Number { BinaryOp::Mul => lhs_complex * rhs, BinaryOp::Div => lhs_complex / rhs, BinaryOp::Pow => lhs_complex.powc(rhs), + _ => return None, }; - Number::Complex(result) + Some(Number::Complex(result)) } (Number::Complex(lhs), Number::Real(rhs)) => { @@ -96,8 +131,9 @@ impl Number { BinaryOp::Mul => lhs * rhs_complex, BinaryOp::Div => lhs / rhs_complex, BinaryOp::Pow => lhs.powf(rhs), + _ => return None, }; - Number::Complex(result) + Some(Number::Complex(result)) } } } From f8dab70bc6d8ed6d24e07333c2e8557ed02df0fa Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 3 Jun 2025 11:44:57 +0300 Subject: [PATCH 02/16] switch to chumsky, ignore units for now --- Cargo.lock | 123 ++++++++++++----------- libraries/math-parser/Cargo.toml | 3 +- libraries/math-parser/src/ast.rs | 2 + libraries/math-parser/src/lib.rs | 136 ++++++++++++------------- libraries/math-parser/src/parser.rs | 149 +++++++++++++++++++++++----- libraries/math-parser/src/value.rs | 33 ++---- 6 files changed, 264 insertions(+), 182 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c39cb7752b..20b77369c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-activity" version = "0.5.2" @@ -737,6 +743,20 @@ dependencies = [ "windows-link", ] +[[package]] +name = "chumsky" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14377e276b2c8300513dff55ba4cc4142b44e5d6de6d00eb5b2307d650bb4ec1" +dependencies = [ + "hashbrown 0.15.2", + "regex-automata 0.3.9", + "serde", + "stacker", + "unicode-ident", + "unicode-segmentation", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -2418,6 +2438,8 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ + "allocator-api2", + "equivalent", "foldhash", ] @@ -3225,7 +3247,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -3339,11 +3361,10 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" name = "math-parser" version = "0.0.0" dependencies = [ + "chumsky", "criterion", "lazy_static", "num-complex", - "pest", - "pest_derive", "thiserror 2.0.12", ] @@ -4134,51 +4155,6 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" -[[package]] -name = "pest" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" -dependencies = [ - "memchr", - "thiserror 2.0.12", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" -dependencies = [ - "pest", - "pest_generator", -] - -[[package]] -name = "pest_generator" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn 2.0.99", -] - -[[package]] -name = "pest_meta" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" -dependencies = [ - "once_cell", - "pest", - "sha2", -] - [[package]] name = "petgraph" version = "0.6.5" @@ -4592,6 +4568,15 @@ version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" +[[package]] +name = "psm" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +dependencies = [ + "cc", +] + [[package]] name = "publicsuffix" version = "2.3.0" @@ -4942,8 +4927,19 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.5", ] [[package]] @@ -4954,9 +4950,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -5755,6 +5757,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -6670,12 +6685,6 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - [[package]] name = "unic-char-property" version = "0.9.0" diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml index b84da885f6..5b13952fc8 100644 --- a/libraries/math-parser/Cargo.toml +++ b/libraries/math-parser/Cargo.toml @@ -8,11 +8,10 @@ description = "Parser for Graphite style mathematics expressions" license = "MIT OR Apache-2.0" [dependencies] -pest = "2.7" -pest_derive = "2.7.11" thiserror = "2.0" lazy_static = "1.5" num-complex = "0.4" +chumsky = "0.10" [dev-dependencies] criterion = "0.5" diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index a107bc492b..a73232b4d5 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -57,7 +57,9 @@ pub enum BinaryOp { Div, Pow, Leq, + Lt, Geq, + Gt, Eq, } diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index d3e7d04df9..a8299b3b1e 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -13,10 +13,10 @@ use executer::EvalError; use parser::ParseError; use value::Value; -pub fn evaluate(expression: &str) -> Result<(Result, Unit), ParseError> { +pub fn evaluate(expression: &str) -> Result, ParseError> { let expr = ast::Node::try_parse_from_str(expression); let context = EvalContext::default(); - expr.map(|(node, unit)| (node.eval(&context), unit)) + expr.map(|node| node.eval(&context)) } #[cfg(test)] @@ -28,21 +28,21 @@ mod tests { const EPSILON: f64 = 1e-10_f64; macro_rules! test_end_to_end{ - ($($name:ident: $input:expr_2021 => ($expected_value:expr_2021, $expected_unit:expr_2021)),* $(,)?) => { + ($($name:ident: $input:expr_2021 => $expected_value:expr_2021),* $(,)?) => { $( #[test] fn $name() { let expected_value = $expected_value; - let expected_unit = $expected_unit; let expr = ast::Node::try_parse_from_str($input); let context = EvalContext::default(); - let (actual_value, actual_unit) = expr.map(|(node, unit)| (node.eval(&context), unit)).unwrap(); + dbg!(&expr); + + let actual_value = expr.map(|node| node.eval(&context)).unwrap(); let actual_value = actual_value.unwrap(); - assert!(actual_unit == expected_unit, "Expected unit {:?} but found unit {:?}", expected_unit, actual_unit); let expected_value = expected_value.into(); @@ -86,108 +86,96 @@ mod tests { test_end_to_end! { // Basic arithmetic and units - infix_addition: "5 + 5" => (10., Unit::BASE_UNIT), - infix_subtraction_units: "5m - 3m" => (2., Unit::LENGTH), - infix_multiplication_units: "4s * 4s" => (16., Unit { length: 0, mass: 0, time: 2 }), - infix_division_units: "8m/2s" => (4., Unit::VELOCITY), + infix_addition: "5 + 5" => 10., // Order of operations - order_of_operations_negative_prefix: "-10 + 5" => (-5., Unit::BASE_UNIT), - order_of_operations_add_multiply: "5+1*1+5" => (11., Unit::BASE_UNIT), - order_of_operations_add_negative_multiply: "5+(-1)*1+5" => (9., Unit::BASE_UNIT), - order_of_operations_sqrt: "sqrt25 + 11" => (16., Unit::BASE_UNIT), - order_of_operations_sqrt_expression: "sqrt(25+11)" => (6., Unit::BASE_UNIT), + order_of_operations_negative_prefix: "-10 + 5" => -5., + order_of_operations_add_multiply: "5+1*1+5" => 11., + order_of_operations_add_negative_multiply: "5+(-1)*1+5" => 9., + order_of_operations_sqrt: "sqrt25 + 11" => 16., + order_of_operations_sqrt_expression: "sqrt(25+11)" => 6., // Parentheses and nested expressions - parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => (64., Unit::BASE_UNIT), - parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => (36., Unit::BASE_UNIT), - parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => (16., Unit::BASE_UNIT), + parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => 64., + parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => 36., + parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => 16., // Square root and nested square root - sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => (10., Unit::BASE_UNIT), - sqrt_nested: "sqrt(sqrt(81))" => (3., Unit::BASE_UNIT), - sqrt_divide_expression: "sqrt((25 + 11) / 9)" => (2., Unit::BASE_UNIT), + sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => 10., + sqrt_nested: "sqrt(sqrt(81))" => 3., + sqrt_divide_expression: "sqrt((25 + 11) / 9)" => 2., // Mixed square root and units - sqrt_multiply_units: "sqrt(16) * 2g + 5g" => (13., Unit::MASS), - sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => (12., Unit::BASE_UNIT), - sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => (16., Unit::BASE_UNIT), + sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => 12., + sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => 16., // Exponentiation - exponent_single: "2^3" => (8., Unit::BASE_UNIT), - exponent_mixed_operations: "2^3 + 4^2" => (24., Unit::BASE_UNIT), - exponent_nested: "2^(3+1)" => (16., Unit::BASE_UNIT), + exponent_single: "2^3" => 8., + exponent_mixed_operations: "2^3 + 4^2" => 24., + exponent_nested: "2^(3+1)" => 16., // Operations with negative values - negative_units_add_multiply: "-5s + (-3 * 2)s" => (-11., Unit::TIME), - negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => (-8., Unit::BASE_UNIT), - negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => (-7., Unit::BASE_UNIT), - multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => (9., Unit::BASE_UNIT), - add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => (8., Unit::BASE_UNIT), - add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => (3., Unit::BASE_UNIT), + negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => -8., + negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => -7., + multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => 9., + add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => 8., + add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => 3., // Mathematical constants - constant_pi: "pi" => (std::f64::consts::PI, Unit::BASE_UNIT), - constant_e: "e" => (std::f64::consts::E, Unit::BASE_UNIT), - constant_phi: "phi" => (1.61803398875, Unit::BASE_UNIT), - constant_tau: "tau" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), - constant_infinity: "inf" => (f64::INFINITY, Unit::BASE_UNIT), - constant_infinity_symbol: "∞" => (f64::INFINITY, Unit::BASE_UNIT), - multiply_pi: "2 * pi" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), - add_e_constant: "e + 1" => (std::f64::consts::E + 1.0, Unit::BASE_UNIT), - multiply_phi_constant: "phi * 2" => (1.61803398875 * 2.0, Unit::BASE_UNIT), - exponent_tau: "2^tau" => (2f64.powf(2.0 * std::f64::consts::PI), Unit::BASE_UNIT), - infinity_subtract_large_number: "inf - 1000" => (f64::INFINITY, Unit::BASE_UNIT), + constant_pi: "pi" => std::f64::consts::PI, + constant_e: "e" => std::f64::consts::E, + constant_phi: "phi" => 1.61803398875, + constant_tau: "tau" => 2.0 * std::f64::consts::PI, + constant_infinity: "inf" => f64::INFINITY, + constant_infinity_symbol: "∞" => f64::INFINITY, + multiply_pi: "2 * pi" => 2.0 * std::f64::consts::PI, + add_e_constant: "e + 1" => std::f64::consts::E + 1.0, + multiply_phi_constant: "phi * 2" => 1.61803398875 * 2.0, + exponent_tau: "2^tau" => 2f64.powf(2.0 * std::f64::consts::PI), + infinity_subtract_large_number: "inf - 1000" => f64::INFINITY, // Trigonometric functions - trig_sin_pi: "sin(pi)" => (0.0, Unit::BASE_UNIT), - trig_cos_zero: "cos(0)" => (1.0, Unit::BASE_UNIT), - trig_tan_pi_div_four: "tan(pi/4)" => (1.0, Unit::BASE_UNIT), - trig_sin_tau: "sin(tau)" => (0.0, Unit::BASE_UNIT), - trig_cos_tau_div_two: "cos(tau/2)" => (-1.0, Unit::BASE_UNIT), + trig_sin_pi: "sin(pi)" => 0.0, + trig_cos_zero: "cos(0)" => 1.0, + trig_tan_pi_div_four: "tan(pi/4)" => 1.0, + trig_sin_tau: "sin(tau)" => 0.0, + trig_cos_tau_div_two: "cos(tau/2)" => -1.0, // Basic if statements - if_true_condition: "if(1){5} else {3}" => (5., Unit::BASE_UNIT), - if_false_condition: "if(0){5} else {3}" => (3., Unit::BASE_UNIT), + if_true_condition: "if(1){5} else {3}" => 5., + if_false_condition: "if(0){5} else {3}" => 3., // Arithmetic conditions - if_arithmetic_true: "if(2+2-4){1} else {0}" => (0., Unit::BASE_UNIT), - if_arithmetic_false: "if(3*2-5){1} else {0}" => (1., Unit::BASE_UNIT), + if_arithmetic_true: "if(2+2-4){1} else {0}" => 0., + if_arithmetic_false: "if(3*2-5){1} else {0}" => 1., // Nested arithmetic - if_complex_arithmetic: "if((5+3)*(2-1)){10} else {20}" => (10., Unit::BASE_UNIT), - if_with_division: "if(8/4-2){15} else {25}" => (15., Unit::BASE_UNIT), + if_complex_arithmetic: "if((5+3)*(2-1)){10} else {20}" => 10., + if_with_division: "if(8/4-2 == 0){15} else {25}" => 15., // Constants in conditions - if_with_pi: "if(pi > 3){1} else {0}" => (1., Unit::BASE_UNIT), - if_with_e: "if(e < 3){1} else {0}" => (0., Unit::BASE_UNIT), + if_with_pi: "if(pi > 3){1} else {0}" => 1., + if_with_e: "if(e < 3){1} else {0}" => 1., // Functions in conditions - if_with_sqrt: "if(sqrt(16) == 4){1} else {0}" => (1., Unit::BASE_UNIT), - if_with_sin: "if(sin(pi)){1} else {0}" => (0., Unit::BASE_UNIT), + if_with_sqrt: "if(sqrt(16) == 4){1} else {0}" => 1., + if_with_sin: "if(sin(pi) == 0.0){1} else {0}" => 0., // Nested if statements - nested_if: "if(1){if(0){1} else {2}} else {3}" => (2., Unit::BASE_UNIT), - nested_if_complex: "if(2-2){if(1){5} else {6}} else {if(1){7} else {8}}" => (7., Unit::BASE_UNIT), - - // If statements with variables - if_with_var: "if(x > 0){1} else {0}" => (1., Unit::BASE_UNIT), - if_var_arithmetic: "if(x + 3 > 5){1} else {0}" => (0., Unit::BASE_UNIT), + nested_if: "if(1){if(0){1} else {2}} else {3}" => 2., + nested_if_complex: "if(2-2 == 0){if(1){5} else {6}} else {if(1){7} else {8}}" => 5., // Mixed operations in conditions and blocks - if_complex_condition: "if(sqrt(16) + sin(pi) < 5){2*pi} else {3*e}" => (2. * std::f64::consts::PI, Unit::BASE_UNIT), - if_complex_blocks: "if(1){2*sqrt(16) + sin(pi/2)} else {3*cos(0) + 4}" => (8., Unit::BASE_UNIT), + if_complex_condition: "if(sqrt(16) + sin(pi) < 5){2*pi} else {3*e}" => 2. * std::f64::consts::PI, + if_complex_blocks: "if(1){2*sqrt(16) + sin(pi/2)} else {3*cos(0) + 4}" => 9., // Edge cases - if_zero: "if(0.0){1} else {2}" => (2., Unit::BASE_UNIT), - if_negative: "if(-1){1} else {2}" => (1., Unit::BASE_UNIT), - if_infinity: "if(inf){1} else {2}" => (1., Unit::BASE_UNIT), + if_zero: "if(0.0){1} else {2}" => 2., + if_negative: "if(-1){1} else {2}" => 1., + if_infinity: "if(inf){1} else {2}" => 1., - // Units in if statements - if_with_units: "if(5m > 3m){10s} else {20s}" => (10., Unit::TIME), - if_mixed_units: "if(2km/h > 1m/s){5kg} else {10kg}" => (10., Unit::MASS), // Complex nested expressions - if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1)){3 + 4 * 2} else {5 - 2 / 1}" => (11., Unit::BASE_UNIT), + if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1)){3 + 4 * 2} else {5 - 2 / 1}" => 11., } } diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index f827914d3a..a45d54f961 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -1,16 +1,14 @@ use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; use crate::context::EvalContext; use crate::value::{Complex, Number, Value}; +use chumsky::container::Seq; +use chumsky::{Parser, prelude::*}; use lazy_static::lazy_static; use num_complex::ComplexFloat; -use pest::Parser; -use pest::iterators::{Pair, Pairs}; -use pest::pratt_parser::{Assoc, Op, PrattParser}; -use pest_derive::Parser; use std::num::{ParseFloatError, ParseIntError}; use thiserror::Error; -#[derive(Parser)] +/*#[derive(Parser)] #[grammar = "./grammer.pest"] struct ExprParser; @@ -25,7 +23,7 @@ lazy_static! { .op(Op::prefix(Rule::sqrt)) .op(Op::prefix(Rule::neg)) }; -} +}*/ #[derive(Error, Debug)] pub enum TypeError { @@ -37,27 +35,128 @@ pub enum TypeError { } #[derive(Error, Debug)] -pub enum ParseError { - #[error("ParseIntError: {0}")] - ParseInt(#[from] ParseIntError), - #[error("ParseFloatError: {0}")] - ParseFloat(#[from] ParseFloatError), - - #[error("TypeError: {0}")] - Type(#[from] TypeError), - - #[error("PestError: {0}")] - Pest(#[from] Box>), +pub enum ParseError<'src> { + /// One or more syntax/lexing errors produced by Chumsky. + #[error("Syntax error(s): {0:#?}")] + Syntax(Vec>), } impl Node { - pub fn try_parse_from_str(s: &str) -> Result<(Node, Unit), ParseError> { - let pairs = ExprParser::parse(Rule::program, s).map_err(Box::new)?; - let (node, metadata) = parse_expr(pairs)?; - Ok((node, metadata.unit)) + pub fn try_parse_from_str(s: &str) -> Result { + let parsed = chumsky_parser().parse(s); + if parsed.has_output() { + Ok(parsed.into_output().unwrap()) + } else { + Err(ParseError::Syntax(parsed.into_errors())) + } } } +pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Err>> { + recursive(|expr| { + let float = text::int(10) + .then(just('.').map(|c: char| c).then(text::int(10)).or_not()) + .then(just('e').or(just('E')).then(one_of("+-").or_not()).then(text::int(10)).or_not()) + .map(|((int_part, opt_frac), opt_exp): ((&str, _), _)| { + let mut s: String = int_part.to_string(); + if let Some((dot, frac)) = opt_frac { + s.push(dot); + s.push_str(frac); + } + if let Some(((e, sign), exp)) = opt_exp { + s.push(e); + if let Some(sign) = sign { + s.push(sign); + } + s.push_str(exp); + } + Node::Lit(Literal::Float(s.parse().unwrap())) + }); + + let constant = choice(( + just("pi").or(just("π")).map(|_| Node::Lit(Literal::Float(std::f64::consts::PI))), + just("tau").or(just("τ")).map(|_| Node::Lit(Literal::Float(std::f64::consts::TAU))), + just("e").map(|_| Node::Lit(Literal::Float(std::f64::consts::E))), + just("phi").or(just("φ")).map(|_| Node::Lit(Literal::Float(1.618_033_988_75))), + just("inf").or(just("∞")).map(|_| Node::Lit(Literal::Float(f64::INFINITY))), + just("i").map(|_| Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), // Assuming Complex impl + just("G").map(|_| Node::Lit(Literal::Float(9.80665))), // gravity + )); + + let ident = text::ident().padded(); + + let var = ident.map(|s: &str| Node::Var(s.to_string())); + + let args = expr.clone().separated_by(just(',')).collect::>().delimited_by(just('('), just(')')); + + let call = ident.then(args).map(|(name, args): (&str, Vec)| Node::FnCall { name: name.to_string(), expr: args }); + + let parens = expr.clone().clone().delimited_by(just('('), just(')')); + + let conditional = just("if") + .padded() + .ignore_then(expr.clone().delimited_by(just('('), just(')'))) + .padded() + .then(expr.clone().delimited_by(just('{'), just('}'))) + .padded() + .then_ignore(just("else")) + .padded() + .then(expr.clone().delimited_by(just('{'), just('}'))) + .padded() + .map(|((cond, if_b), else_b): ((Node, _), _)| Node::Conditional { + condition: Box::new(cond), + if_block: Box::new(if_b), + else_block: Box::new(else_b), + }); + + let atom = choice((conditional, float, constant, call, parens, var)); + + let unary = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))) + .padded() + .repeated() + .foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }); + + let pow = unary.clone().foldl(just('^').to(BinaryOp::Pow).padded().then(unary).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + let product = pow + .clone() + .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { + Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + } + }); + + let sum = product.clone().foldl( + choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded().then(product).repeated(), + |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }, + ); + + let cmp = sum.clone().foldl( + choice((just("<").to(BinaryOp::Lt), just(">").to(BinaryOp::Gt), just("==").to(BinaryOp::Eq))) + .padded() + .then(sum) + .repeated(), + |lhs: Node, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }, + ); + + cmp.padded() + }) +} + struct NodeMetadata { pub unit: Unit, } @@ -68,7 +167,7 @@ impl NodeMetadata { } } -fn parse_unit(pairs: Pairs) -> Result<(Unit, f64), ParseError> { +/*fn parse_unit(pairs: Pairs) -> Result<(Unit, f64), ParseError> { let mut scale = 1.0; let mut length = 0; let mut mass = 0; @@ -326,7 +425,7 @@ fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { Ok((node, NodeMetadata::new(unit))) }) .parse(pairs) -} +}*/ //TODO: set up Unit test for Units #[cfg(test)] @@ -338,7 +437,7 @@ mod tests { #[test] fn $name() { let result = Node::try_parse_from_str($input).unwrap(); - assert_eq!(result.0, $expected); + assert_eq!(result, $expected); } )* }; @@ -376,7 +475,7 @@ mod tests { expr: vec![Node::Lit(Literal::Float(16.0))] }, - test_parse_complex_expr: "(1 + 2) 3 - 4 ^ 2" => Node::BinOp { + test_parse_complex_expr: "(1 + 2) * 3 - 4 ^ 2" => Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::Lit(Literal::Float(1.0))), diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index 387959d9f8..2a7065b579 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -61,28 +61,16 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powf(rhs), - BinaryOp::Leq => { - if lhs > rhs { - 1.0 - } else { - 0.0 - } - } - BinaryOp::Geq => { - if lhs < rhs { - 1.0 - } else { - 0.0 - } - } - BinaryOp::Eq => { - if lhs == rhs { - 1.0 - } else { - 0.0 - } + BinaryOp::Leq => (lhs <= rhs) as u8 as f64, + BinaryOp::Lt => { + println!("{lhs} < {rhs}: {}", (lhs < rhs) as u8); + (lhs < rhs) as u8 as f64 } + BinaryOp::Geq => (lhs >= rhs) as u8 as f64, + BinaryOp::Gt => (lhs > rhs) as u8 as f64, + BinaryOp::Eq => (lhs == rhs) as u8 as f64, }; + Some(Number::Real(result)) } @@ -93,10 +81,7 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powc(rhs), - BinaryOp::Leq => { - return None; - } - BinaryOp::Geq => { + BinaryOp::Leq | BinaryOp::Lt | BinaryOp::Geq | BinaryOp::Gt => { return None; } BinaryOp::Eq => { From 2d5c47cc4d0976b2a388c282f6df0a6a022d1299 Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 3 Jun 2025 12:11:29 +0300 Subject: [PATCH 03/16] fix benches and improve compile time --- libraries/math-parser/benches/bench.rs | 2 +- libraries/math-parser/src/parser.rs | 335 ++----------------------- 2 files changed, 28 insertions(+), 309 deletions(-) diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs index fd1824c9a0..fc27144491 100644 --- a/libraries/math-parser/benches/bench.rs +++ b/libraries/math-parser/benches/bench.rs @@ -16,7 +16,7 @@ macro_rules! generate_benchmarks { fn evaluation_bench(c: &mut Criterion) { $( - let expr = ast::Node::try_parse_from_str($input).unwrap().0; + let expr = ast::Node::try_parse_from_str($input).unwrap(); let context = EvalContext::default(); c.bench_function(concat!("eval ", $input), |b| { diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index a45d54f961..e301e4f212 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -8,37 +8,26 @@ use num_complex::ComplexFloat; use std::num::{ParseFloatError, ParseIntError}; use thiserror::Error; -/*#[derive(Parser)] -#[grammar = "./grammer.pest"] -struct ExprParser; - -lazy_static! { - static ref PRATT_PARSER: PrattParser = { - PrattParser::new() - .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) - .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left)) - .op(Op::infix(Rule::pow, Assoc::Right)) - .op(Op::infix(Rule::geq, Assoc::Left) | Op::infix(Rule::leq, Assoc::Left) | Op::infix(Rule::eq, Assoc::Left)) - .op(Op::postfix(Rule::EOI)) - .op(Op::prefix(Rule::sqrt)) - .op(Op::prefix(Rule::neg)) - }; -}*/ - -#[derive(Error, Debug)] -pub enum TypeError { - #[error("Invalid BinOp: {0:?} {1:?} {2:?}")] - InvalidBinaryOp(Unit, BinaryOp, Unit), - - #[error("Invalid UnaryOp: {0:?}")] - InvalidUnaryOp(Unit, UnaryOp), -} - #[derive(Error, Debug)] pub enum ParseError<'src> { - /// One or more syntax/lexing errors produced by Chumsky. #[error("Syntax error(s): {0:#?}")] Syntax(Vec>), + #[error("nig")] + Nigga(), +} + +#[cfg(not(debug_assertions))] // In release build: skip boxing +macro_rules! maybe_boxed { + ($p:expr) => { + $p + }; +} + +#[cfg(debug_assertions)] // In debug build: box for faster compilation +macro_rules! maybe_boxed { + ($p:expr) => { + $p.boxed() + }; } impl Node { @@ -109,7 +98,7 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er else_block: Box::new(else_b), }); - let atom = choice((conditional, float, constant, call, parens, var)); + let atom = maybe_boxed!(choice((conditional, float, constant, call, parens, var))); let unary = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))) .padded() @@ -122,15 +111,16 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er rhs: Box::new(rhs), }); - let product = pow - .clone() - .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { - Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - } - }); + let product = maybe_boxed!( + pow.clone() + .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { + Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + } + }) + ); let sum = product.clone().foldl( choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded().then(product).repeated(), @@ -157,277 +147,6 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er }) } -struct NodeMetadata { - pub unit: Unit, -} - -impl NodeMetadata { - pub fn new(unit: Unit) -> Self { - Self { unit } - } -} - -/*fn parse_unit(pairs: Pairs) -> Result<(Unit, f64), ParseError> { - let mut scale = 1.0; - let mut length = 0; - let mut mass = 0; - let mut time = 0; - - for pair in pairs { - println!("found rule: {:?}", pair.as_rule()); - match pair.as_rule() { - Rule::nano => scale *= 1e-9, - Rule::micro => scale *= 1e-6, - Rule::milli => scale *= 1e-3, - Rule::centi => scale *= 1e-2, - Rule::deci => scale *= 1e-1, - Rule::deca => scale *= 1e1, - Rule::hecto => scale *= 1e2, - Rule::kilo => scale *= 1e3, - Rule::mega => scale *= 1e6, - Rule::giga => scale *= 1e9, - Rule::tera => scale *= 1e12, - - Rule::meter => length = 1, - Rule::gram => mass = 1, - Rule::second => time = 1, - - _ => unreachable!(), // All possible rules should be covered - } - } - - Ok((Unit { length, mass, time }, scale)) -} - -fn parse_const(pair: Pair) -> Literal { - match pair.as_rule() { - Rule::infinity => Literal::Float(f64::INFINITY), - Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)), - Rule::pi => Literal::Float(std::f64::consts::PI), - Rule::tau => Literal::Float(2.0 * std::f64::consts::PI), - Rule::euler_number => Literal::Float(std::f64::consts::E), - Rule::golden_ratio => Literal::Float(1.61803398875), - _ => unreachable!("Unexpected constant: {:?}", pair), - } -} - -fn parse_lit(mut pairs: Pairs) -> Result<(Literal, Unit), ParseError> { - let literal = match pairs.next() { - Some(lit) => match lit.as_rule() { - Rule::int => { - let value = lit.as_str().parse::()? as f64; - Literal::Float(value) - } - Rule::float => { - let value = lit.as_str().parse::()?; - Literal::Float(value) - } - Rule::unit => { - let (unit, scale) = parse_unit(lit.into_inner())?; - return Ok((Literal::Float(scale), unit)); - } - rule => unreachable!("unexpected rule: {:?}", rule), - }, - None => unreachable!("expected rule"), // No literal found - }; - - if let Some(unit_pair) = pairs.next() { - let unit_pairs = unit_pair.into_inner(); // Get the inner pairs for the unit - let (unit, scale) = parse_unit(unit_pairs)?; - - println!("found unit: {:?}", unit); - - Ok(( - match literal { - Literal::Float(num) => Literal::Float(num * scale), - Literal::Complex(num) => Literal::Complex(num * scale), - }, - unit, - )) - } else { - Ok((literal, Unit::BASE_UNIT)) - } -} - -fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { - PRATT_PARSER - .map_primary(|primary| { - Ok(match primary.as_rule() { - Rule::lit => { - let (lit, unit) = parse_lit(primary.into_inner())?; - - (Node::Lit(lit), NodeMetadata { unit }) - } - Rule::fn_call => { - let mut pairs = primary.into_inner(); - let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string(); - - ( - Node::FnCall { - name, - expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::, ParseError>>()?, - }, - NodeMetadata::new(Unit::BASE_UNIT), - ) - } - Rule::constant => { - let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child")); - - (Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT)) - } - Rule::ident => { - let name = primary.as_str().to_string(); - - (Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT)) - } - Rule::expr => parse_expr(primary.into_inner())?, - Rule::float => { - let value = primary.as_str().parse::()?; - (Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT)) - } - Rule::conditional => { - let mut inner = primary.into_inner(); - - let condition = parse_expr(inner.next().expect("expected conditional to have condition").into_inner())?; - let if_block = parse_expr(inner.next().expect("expected conditional to have if block").into_inner())?; - let else_block = parse_expr(inner.next().expect("expected conditional to have else block").into_inner())?; - - ( - Node::Conditional { - condition: Box::new(condition.0), - if_block: Box::new(if_block.0), - else_block: Box::new(else_block.0), - }, - if_block.1, - ) - } - rule => unreachable!("unexpected rule: {:?}", rule), - }) - }) - .map_prefix(|op, rhs| { - let (rhs, rhs_metadata) = rhs?; - let op = match op.as_rule() { - Rule::neg => UnaryOp::Neg, - Rule::sqrt => UnaryOp::Sqrt, - - rule => unreachable!("unexpected rule: {:?}", rule), - }; - - let node = Node::UnaryOp { expr: Box::new(rhs), op }; - let unit = rhs_metadata.unit; - - let unit = if !unit.is_base() { - match op { - UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit { - length: unit.length / 2, - mass: unit.mass / 2, - time: unit.time / 2, - }, - UnaryOp::Neg => unit, - op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))), - } - } else { - Unit::BASE_UNIT - }; - - Ok((node, NodeMetadata::new(unit))) - }) - .map_postfix(|lhs, op| { - let op = match op.as_rule() { - Rule::EOI => return lhs, - rule => unreachable!("unexpected rule: {:?}", rule), - }; - }) - .map_infix(|lhs, op, rhs| { - let (lhs, lhs_metadata) = lhs?; - let (rhs, rhs_metadata) = rhs?; - - let op = match op.as_rule() { - Rule::add => BinaryOp::Add, - Rule::sub => BinaryOp::Sub, - Rule::mul => BinaryOp::Mul, - Rule::div => BinaryOp::Div, - Rule::pow => BinaryOp::Pow, - Rule::paren => BinaryOp::Mul, - Rule::leq => BinaryOp::Leq, - Rule::geq => BinaryOp::Geq, - Rule::eq => BinaryOp::Eq, - rule => unreachable!("unexpected rule: {:?}", rule), - }; - - let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit); - - let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) { - (true, true) => match op { - BinaryOp::Mul => Unit { - length: lhs_unit.length + rhs_unit.length, - mass: lhs_unit.mass + rhs_unit.mass, - time: lhs_unit.time + rhs_unit.time, - }, - BinaryOp::Div => Unit { - length: lhs_unit.length - rhs_unit.length, - mass: lhs_unit.mass - rhs_unit.mass, - time: lhs_unit.time - rhs_unit.time, - }, - BinaryOp::Add | BinaryOp::Sub => { - if lhs_unit == rhs_unit { - lhs_unit - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); - } - } - BinaryOp::Pow => { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); - } - BinaryOp::Eq | BinaryOp::Leq | BinaryOp::Geq => { - if lhs_unit == rhs_unit { - Unit::base_unit() - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); - } - } - }, - - (true, false) => match op { - BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))), - BinaryOp::Pow => { - //TODO: improve error type - //TODO: support fractions - if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) { - if (val - val as i32 as f64).abs() <= f64::EPSILON { - Unit { - length: lhs_unit.length * val.round() as i32, - mass: lhs_unit.mass * val.round() as i32, - time: lhs_unit.time * val as i32, - } - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); - } - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); - } - } - _ => lhs_unit, - }, - (false, true) => match op { - BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))), - _ => rhs_unit, - }, - (false, false) => Unit::BASE_UNIT, - }; - - let node = Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }; - - Ok((node, NodeMetadata::new(unit))) - }) - .parse(pairs) -}*/ - -//TODO: set up Unit test for Units #[cfg(test)] mod tests { use super::*; From af707d9d3a9e0c370b7f5313a23151c68043c6cb Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 3 Jun 2025 17:20:23 +0300 Subject: [PATCH 04/16] always box --- libraries/math-parser/src/parser.rs | 38 +++++++++-------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index e301e4f212..3c56c46c9d 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -12,22 +12,6 @@ use thiserror::Error; pub enum ParseError<'src> { #[error("Syntax error(s): {0:#?}")] Syntax(Vec>), - #[error("nig")] - Nigga(), -} - -#[cfg(not(debug_assertions))] // In release build: skip boxing -macro_rules! maybe_boxed { - ($p:expr) => { - $p - }; -} - -#[cfg(debug_assertions)] // In debug build: box for faster compilation -macro_rules! maybe_boxed { - ($p:expr) => { - $p.boxed() - }; } impl Node { @@ -98,7 +82,7 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er else_block: Box::new(else_b), }); - let atom = maybe_boxed!(choice((conditional, float, constant, call, parens, var))); + let atom = choice((conditional, float, constant, call, parens, var)).boxed(); let unary = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))) .padded() @@ -111,16 +95,16 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er rhs: Box::new(rhs), }); - let product = maybe_boxed!( - pow.clone() - .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { - Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - } - }) - ); + let product = pow + .clone() + .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { + Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + } + }) + .boxed(); let sum = product.clone().foldl( choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded().then(product).repeated(), From bc3d77f95b1814a3e8b4510445f5a5b9e9e26aa8 Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 3 Jun 2025 20:29:51 +0300 Subject: [PATCH 05/16] fix editor compilation and parser operator precedence --- libraries/math-parser/src/parser.rs | 62 ++++++++++++++--------------- node-graph/gcore/src/ops.rs | 2 +- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 3c56c46c9d..1250d57789 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -84,12 +84,27 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er let atom = choice((conditional, float, constant, call, parens, var)).boxed(); - let unary = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))) - .padded() - .repeated() - .foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }); + let add_op = choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded(); + let mul_op = choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded(); + let pow_op = just('^').to(BinaryOp::Pow).padded(); + let unary_op = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))).padded(); + let cmp_op = choice(( + just("<").to(BinaryOp::Lt), + just("<=").to(BinaryOp::Leq), + just(">").to(BinaryOp::Gt), + just(">=").to(BinaryOp::Geq), + just("==").to(BinaryOp::Eq), + )); + + let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }); - let pow = unary.clone().foldl(just('^').to(BinaryOp::Pow).padded().then(unary).repeated(), |lhs, (op, rhs)| Node::BinOp { + let cmp = unary.clone().foldl(cmp_op.padded().then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + let pow = cmp.clone().foldl(pow_op.then(cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), @@ -97,37 +112,20 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er let product = pow .clone() - .foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| { - Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - } - }) - .boxed(); - - let sum = product.clone().foldl( - choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded().then(product).repeated(), - |lhs, (op, rhs)| Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }, - ); - - let cmp = sum.clone().foldl( - choice((just("<").to(BinaryOp::Lt), just(">").to(BinaryOp::Gt), just("==").to(BinaryOp::Eq))) - .padded() - .then(sum) - .repeated(), - |lhs: Node, (op, rhs)| Node::BinOp { + .foldl(mul_op.then(pow).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), - }, - ); + }) + .boxed(); + + let sum = product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); - cmp.padded() + sum.padded() }) } diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index 99d6b291d3..3ea14373f0 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -51,7 +51,7 @@ fn math( #[default(1.)] operand_b: U, ) -> U { - let (node, _unit) = match ast::Node::try_parse_from_str(&expression) { + let node = match ast::Node::try_parse_from_str(&expression) { Ok(expr) => expr, Err(e) => { warn!("Invalid expression: `{expression}`\n{e:?}"); From c6f0fcb8caa3a1503d81da3a4a8ef0f715892ef7 Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 3 Jun 2025 20:38:27 +0300 Subject: [PATCH 06/16] maybe now its actually fixed --- frontend/wasm/src/editor_api.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/wasm/src/editor_api.rs b/frontend/wasm/src/editor_api.rs index 7284cc3fce..8834cb5ab6 100644 --- a/frontend/wasm/src/editor_api.rs +++ b/frontend/wasm/src/editor_api.rs @@ -964,7 +964,6 @@ pub fn evaluate_math_expression(expression: &str) -> Option { let value = math_parser::evaluate(expression) .inspect_err(|err| error!("Math parser error on \"{expression}\": {err}")) .ok()? - .0 .inspect_err(|err| error!("Math evaluate error on \"{expression}\": {err} ")) .ok()?; let Some(real) = value.as_real() else { From 1e30f44280ce5a82f11528b9321fc37b49986adb Mon Sep 17 00:00:00 2001 From: urisinger Date: Mon, 9 Jun 2025 11:27:22 +0300 Subject: [PATCH 07/16] remove stacker freature from math-parser --- libraries/math-parser/Cargo.toml | 2 +- libraries/math-parser/src/grammer.pest | 62 -------------------------- 2 files changed, 1 insertion(+), 63 deletions(-) delete mode 100644 libraries/math-parser/src/grammer.pest diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml index 5b13952fc8..38c3cef49a 100644 --- a/libraries/math-parser/Cargo.toml +++ b/libraries/math-parser/Cargo.toml @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" thiserror = "2.0" lazy_static = "1.5" num-complex = "0.4" -chumsky = "0.10" +chumsky = { version = "0.10", features = ["std"] } [dev-dependencies] criterion = "0.5" diff --git a/libraries/math-parser/src/grammer.pest b/libraries/math-parser/src/grammer.pest deleted file mode 100644 index d0aa643aea..0000000000 --- a/libraries/math-parser/src/grammer.pest +++ /dev/null @@ -1,62 +0,0 @@ -WHITESPACE = _{ " " | "\t" } - -// TODO: Proper indentation and formatting -program = _{ SOI ~ expr ~ EOI } - -expr = { atom ~ (infix ~ atom)* } -atom = _{ prefix? ~ primary } -infix = _{ add | sub | mul | div | pow | leq | geq | eq | paren } -add = { "+" } // Addition -sub = { "-" } // Subtraction -mul = { "*" } // Multiplication -div = { "/" } // Division -mod = { "%" } // Modulo -pow = { "^" } // Exponentiation -leq = { "<" } -geq = { ">" } -eq = { "=="} -paren = { "" } // Implicit multiplication operator - -prefix = _{ neg | sqrt } -neg = { "-" } // Negation -sqrt = { "sqrt" } - - -primary = _{ ("(" ~ expr ~ ")") | conditional | lit | constant | fn_call | ident } -conditional = { "if" ~ "(" ~ expr ~ ")" ~ "{" ~ expr ~ "}" ~ "else" ~ "{" ~ expr ~ "}"} -fn_call = { ident ~ "(" ~ expr ~ ("," ~ expr)* ~ ")" } -ident = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } -lit = { unit | ((float | int) ~ unit?) } - -float = @{ int ~ "." ~ int? ~ exp? | int ~ exp } -exp = _{ ^"e" ~ ("+" | "-")? ~ int } -int = @{ ASCII_DIGIT+ } - -unit = ${ (scale ~ base_unit) | base_unit ~ !ident} -base_unit = _{ meter | second | gram } -meter = { "m" } -second = { "s" } -gram = { "g" } - -scale = _{ nano | micro | milli | centi | deci | deca | hecto | kilo | mega | giga | tera } -nano = { "n" } -micro = { "µ" | "u" } -milli = { "m" } -centi = { "c" } -deci = { "d" } -deca = { "da" } -hecto = { "h" } -kilo = { "k" } -mega = { "M" } -giga = { "G" } -tera = { "T" } - -// Constants -constant = { infinity | imaginary_unit | pi | tau | euler_number | golden_ratio | gravity_acceleration } -infinity = { "inf" | "INF" | "infinity" | "INFINITY" | "∞" } -imaginary_unit = { "i" | "I" } -pi = { "pi" | "PI" | "π" } -tau = { "tau" | "TAU" | "τ" } -euler_number = { "e" } -golden_ratio = { "phi" | "PHI" | "φ" } -gravity_acceleration = { "G" } From 01bc7726bd971771016d6e85be2d0ab737cf1830 Mon Sep 17 00:00:00 2001 From: urisinger Date: Mon, 9 Jun 2025 11:48:26 +0300 Subject: [PATCH 08/16] fix again --- Cargo.lock | 23 ----------------------- libraries/math-parser/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 20b77369c4..463dd7e585 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -752,7 +752,6 @@ dependencies = [ "hashbrown 0.15.2", "regex-automata 0.3.9", "serde", - "stacker", "unicode-ident", "unicode-segmentation", ] @@ -4568,15 +4567,6 @@ version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" -[[package]] -name = "psm" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" -dependencies = [ - "cc", -] - [[package]] name = "publicsuffix" version = "2.3.0" @@ -5757,19 +5747,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "stacker" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" -dependencies = [ - "cc", - "cfg-if", - "libc", - "psm", - "windows-sys 0.59.0", -] - [[package]] name = "static_assertions" version = "1.1.0" diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml index 38c3cef49a..efedc5ede0 100644 --- a/libraries/math-parser/Cargo.toml +++ b/libraries/math-parser/Cargo.toml @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" thiserror = "2.0" lazy_static = "1.5" num-complex = "0.4" -chumsky = { version = "0.10", features = ["std"] } +chumsky = { version = "0.10", default-features = false, features = ["std"] } [dev-dependencies] criterion = "0.5" From c30a8aa9fc8cf6f64d3f5b6936739551f0d423a8 Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Mon, 9 Jun 2025 02:44:39 -0700 Subject: [PATCH 09/16] Code review --- libraries/math-parser/src/lib.rs | 5 ++++- libraries/math-parser/src/parser.rs | 4 ++-- libraries/math-parser/src/value.rs | 5 +---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index a8299b3b1e..817183386e 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -85,8 +85,11 @@ mod tests { } test_end_to_end! { - // Basic arithmetic and units + // Basic arithmetic infix_addition: "5 + 5" => 10., + infix_subtraction_units: "5 - 3" => 2., + infix_multiplication_units: "4 * 4" => 16., + infix_division_units: "8/2" => 4., // Order of operations order_of_operations_negative_prefix: "-10 + 5" => -5., diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 1250d57789..4385135340 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -52,8 +52,8 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er just("e").map(|_| Node::Lit(Literal::Float(std::f64::consts::E))), just("phi").or(just("φ")).map(|_| Node::Lit(Literal::Float(1.618_033_988_75))), just("inf").or(just("∞")).map(|_| Node::Lit(Literal::Float(f64::INFINITY))), - just("i").map(|_| Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), // Assuming Complex impl - just("G").map(|_| Node::Lit(Literal::Float(9.80665))), // gravity + just("i").map(|_| Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), // Assuming `Complex` impl + just("G").map(|_| Node::Lit(Literal::Float(9.80665))), // Standard gravity on Earth )); let ident = text::ident().padded(); diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index 2a7065b579..c4e5217ba8 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -62,10 +62,7 @@ impl Number { BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powf(rhs), BinaryOp::Leq => (lhs <= rhs) as u8 as f64, - BinaryOp::Lt => { - println!("{lhs} < {rhs}: {}", (lhs < rhs) as u8); - (lhs < rhs) as u8 as f64 - } + BinaryOp::Lt => (lhs < rhs) as u8 as f64, BinaryOp::Geq => (lhs >= rhs) as u8 as f64, BinaryOp::Gt => (lhs > rhs) as u8 as f64, BinaryOp::Eq => (lhs == rhs) as u8 as f64, From 71b96e851e283227f7e1e5c54f772ce28ec08c7b Mon Sep 17 00:00:00 2001 From: urisinger Date: Mon, 9 Jun 2025 15:40:59 +0300 Subject: [PATCH 10/16] add lexing stage, make variables and functions prefixed to prepare for implicit multiplication --- libraries/math-parser/src/ast.rs | 2 +- libraries/math-parser/src/constants.rs | 8 ++ libraries/math-parser/src/lexer.rs | 159 ++++++++++++++++++++++++ libraries/math-parser/src/lib.rs | 68 +++++------ libraries/math-parser/src/parser.rs | 161 +++++++++++-------------- 5 files changed, 274 insertions(+), 124 deletions(-) create mode 100644 libraries/math-parser/src/lexer.rs diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index a73232b4d5..051db8466a 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -37,7 +37,7 @@ impl Unit { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Literal { Float(f64), Complex(Complex), diff --git a/libraries/math-parser/src/constants.rs b/libraries/math-parser/src/constants.rs index f439be19d8..cdf3b2439e 100644 --- a/libraries/math-parser/src/constants.rs +++ b/libraries/math-parser/src/constants.rs @@ -9,6 +9,14 @@ lazy_static! { pub static ref DEFAULT_FUNCTIONS: HashMap<&'static str, FunctionImplementation> = { let mut map: HashMap<&'static str, FunctionImplementation> = HashMap::new(); + map.insert( + "sqrt", + Box::new(|values| match values{ + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sqrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sqrt()))), + _ => None, + }) + ); map.insert( "sin", Box::new(|values| match values { diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs new file mode 100644 index 0000000000..a6a1dd9e67 --- /dev/null +++ b/libraries/math-parser/src/lexer.rs @@ -0,0 +1,159 @@ +// ── lexer.rs ─────────────────────────────────────────────────────────── +use crate::ast::Literal; +use chumsky::input::{Input, ValueInput}; +use chumsky::prelude::*; +use chumsky::span::SimpleSpan; +use chumsky::text::{ident, int}; +use num_complex::Complex64; +use std::ops::Range; + +pub type Span = SimpleSpan; + +#[derive(Clone, Debug, PartialEq)] +pub enum Token<'src> { + // literals ---------------------------------------------------------------- + Const(Literal), // numeric or complex constants recognised at lex‑time + Var(&'src str), // #identifier (variables) + Call(&'src str), + // punctuation ------------------------------------------------------------- + LParen, + RParen, + Comma, + Plus, + Minus, + Star, + Slash, + Caret, + // comparison -------------------------------------------------------------- + Lt, + Le, + Gt, + Ge, + EqEq, + // keywords ---------------------------------------------------------------- + If, +} + +pub fn lexer<'src>() -> impl Parser<'src, &'src str, Vec<(Token<'src>, Span)>, extra::Err>> { + // ── numbers ──────────────────────────────────────────────────────────── + let num = int(10) + .then(just('.').then(int(10)).or_not()) + .then(just('e').or(just('E')).then(one_of("+-").or_not()).then(int(10)).or_not()) + .map(|((int_part, frac), exp): ((&str, _), _)| { + let mut s = int_part.to_string(); + if let Some((_, frac)) = frac { + s.push('.'); + s.push_str(frac); + } + if let Some(((e, sign), exp)) = exp { + s.push(e); + if let Some(sign) = sign { + s.push(sign); + } + s.push_str(exp); + } + Token::Const(Literal::Float(s.parse::().unwrap())) + }); + + // ── single‑char symbols ──────────────────────────────────────────────── + let sym = choice(( + just('(').to(Token::LParen), + just(')').to(Token::RParen), + just(',').to(Token::Comma), + just('+').to(Token::Plus), + just('-').to(Token::Minus), + just('*').to(Token::Star), + just('/').to(Token::Slash), + just('^').to(Token::Caret), + )); + + // ── comparison operators ─────────────────────────────────────────────── + let cmp = choice(( + just("<=").to(Token::Le), + just(">=").to(Token::Ge), + just("==").to(Token::EqEq), + just('<').to(Token::Lt), + just('>').to(Token::Gt), + )); + + let kw_token = |w, t| just(w).padded().to(t); + + let kw_lit = |w, lit: Literal| just(w).padded().to(lit); + + let const_token = choice(( + kw_lit("pi", Literal::Float(std::f64::consts::PI)), + kw_lit("π", Literal::Float(std::f64::consts::PI)), + kw_lit("tau", Literal::Float(std::f64::consts::TAU)), + kw_lit("τ", Literal::Float(std::f64::consts::TAU)), + kw_lit("e", Literal::Float(std::f64::consts::E)), + kw_lit("phi", Literal::Float(1.618_033_988_75)), + kw_lit("φ", Literal::Float(1.618_033_988_75)), + kw_lit("inf", Literal::Float(f64::INFINITY)), + kw_lit("∞", Literal::Float(f64::INFINITY)), + kw_lit("i", Literal::Complex(Complex64::new(0.0, 1.0))), + kw_lit("G", Literal::Float(9.80665)), + )) + .map(Token::Const); + + let var_token = just('#').ignore_then(ident()).map(Token::Var); + let call_token = just('@').ignore_then(ident()).map(Token::Call); + + choice((num, kw_token("if", Token::If), const_token, cmp, sym, var_token, call_token)) + .map_with(|t, e| (t, e.span())) + .padded() + .repeated() + .collect() +} + +#[derive(Debug)] +pub struct TokenStream<'src> { + tokens: Vec<(Token<'src>, Span)>, +} + +impl<'src> TokenStream<'src> { + pub fn new(tokens: Vec<(Token<'src>, Span)>) -> Self { + TokenStream { tokens } + } +} + +impl<'src> Input<'src> for TokenStream<'src> { + type Token = (Token<'src>, Span); + type Span = Span; + type Cursor = usize; + type MaybeToken = (Token<'src>, Span); + type Cache = Self; + + fn begin(self) -> (Self::Cursor, Self::Cache) { + (0, self) + } + + fn cursor_location(cursor: &Self::Cursor) -> usize { + *cursor + } + + #[inline(always)] + unsafe fn next_maybe(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { + if let Some(tok) = this.tokens.get(*cursor) { + *cursor += 1; + Some(tok.clone()) + } else { + None + } + } + + #[inline(always)] + unsafe fn span(_this: &mut Self::Cache, range: Range<&Self::Cursor>) -> Self::Span { + (*range.start..*range.end).into() + } +} + +impl<'src> ValueInput<'src> for TokenStream<'src> { + unsafe fn next(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { + if let Some(tok) = this.tokens.get(*cursor) { + *cursor += 1; + Some(tok.clone()) + } else { + None + } + } +} diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index 817183386e..47f37f1e64 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -4,6 +4,7 @@ pub mod ast; mod constants; pub mod context; pub mod executer; +pub mod lexer; pub mod parser; pub mod value; @@ -95,8 +96,8 @@ mod tests { order_of_operations_negative_prefix: "-10 + 5" => -5., order_of_operations_add_multiply: "5+1*1+5" => 11., order_of_operations_add_negative_multiply: "5+(-1)*1+5" => 9., - order_of_operations_sqrt: "sqrt25 + 11" => 16., - order_of_operations_sqrt_expression: "sqrt(25+11)" => 6., + order_of_operations_sqrt: "@sqrt(25) + 11" => 16., + order_of_operations_sqrt_expression: "@sqrt(25+11)" => 6., // Parentheses and nested expressions parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => 64., @@ -104,13 +105,13 @@ mod tests { parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => 16., // Square root and nested square root - sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => 10., - sqrt_nested: "sqrt(sqrt(81))" => 3., - sqrt_divide_expression: "sqrt((25 + 11) / 9)" => 2., + sqrt_chain_operations: "@sqrt(16) + @sqrt(9) * @sqrt(4)" => 10., + sqrt_nested: "@sqrt(@sqrt(81))" => 3., + sqrt_divide_expression: "@sqrt((25 + 11) / 9)" => 2., // Mixed square root and units - sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => 12., - sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => 16., + sqrt_add_multiply: "@sqrt(49) - 1 + 2 * 3" => 12., + sqrt_addition_multiply: "(@sqrt(36) + 2) * 2" => 16., // Exponentiation exponent_single: "2^3" => 8., @@ -119,10 +120,10 @@ mod tests { // Operations with negative values negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => -8., - negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => -7., - multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => 9., - add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => 8., - add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => 3., + negative_sqrt_addition: "-(@sqrt(16) + @sqrt(9))" => -7., + multiply_sqrt_subtract: "5 * 2 + @sqrt(16) / 2 - 3" => 9., + add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - @sqrt(25)" => 8., + add_sqrt_subtract_nested_multiply: "10 + @sqrt(64) - (5 * (2 + 1))" => 3., // Mathematical constants constant_pi: "pi" => std::f64::consts::PI, @@ -138,47 +139,44 @@ mod tests { infinity_subtract_large_number: "inf - 1000" => f64::INFINITY, // Trigonometric functions - trig_sin_pi: "sin(pi)" => 0.0, - trig_cos_zero: "cos(0)" => 1.0, - trig_tan_pi_div_four: "tan(pi/4)" => 1.0, - trig_sin_tau: "sin(tau)" => 0.0, - trig_cos_tau_div_two: "cos(tau/2)" => -1.0, + trig_sin_pi: "@sin(pi)" => 0.0, + trig_cos_zero: "@cos(0)" => 1.0, + trig_tan_pi_div_four: "@tan(pi/4)" => 1.0, + trig_sin_tau: "@sin(tau)" => 0.0, + trig_cos_tau_div_two: "@cos(tau/2)" => -1.0, // Basic if statements - if_true_condition: "if(1){5} else {3}" => 5., - if_false_condition: "if(0){5} else {3}" => 3., + if_true_condition: "if(1,5,3)" => 5., + if_false_condition: "if(0, 5, 3)" => 3., // Arithmetic conditions - if_arithmetic_true: "if(2+2-4){1} else {0}" => 0., - if_arithmetic_false: "if(3*2-5){1} else {0}" => 1., + if_arithmetic_true: "if(2+2-4, 1 , 0)" => 0., + if_arithmetic_false: "if(3*2-5, 1, 0)" => 1., // Nested arithmetic - if_complex_arithmetic: "if((5+3)*(2-1)){10} else {20}" => 10., - if_with_division: "if(8/4-2 == 0){15} else {25}" => 15., + if_complex_arithmetic: "if((5+3)*(2-1), 10, 20)" => 10., + if_with_division: "if(8/4-2 == 0,15, 25)" => 15., // Constants in conditions - if_with_pi: "if(pi > 3){1} else {0}" => 1., - if_with_e: "if(e < 3){1} else {0}" => 1., + if_with_pi: "if(pi > 3, 1, 0)" => 1., + if_with_e: "if(e < 3, 1, 0)" => 1., // Functions in conditions - if_with_sqrt: "if(sqrt(16) == 4){1} else {0}" => 1., - if_with_sin: "if(sin(pi) == 0.0){1} else {0}" => 0., + if_with_sqrt: "if(@sqrt(16) == 4, 1, 0)" => 1., + if_with_sin: "if(@sin(pi) == 0.0, 1, 0)" => 0., // Nested if statements - nested_if: "if(1){if(0){1} else {2}} else {3}" => 2., - nested_if_complex: "if(2-2 == 0){if(1){5} else {6}} else {if(1){7} else {8}}" => 5., + nested_if: "if(1, if(0, 1, 2), 3)" => 2., + nested_if_complex: "if(2-2 == 0, if(1, 5, 6), if(1, 7, 8))" => 5., // Mixed operations in conditions and blocks - if_complex_condition: "if(sqrt(16) + sin(pi) < 5){2*pi} else {3*e}" => 2. * std::f64::consts::PI, - if_complex_blocks: "if(1){2*sqrt(16) + sin(pi/2)} else {3*cos(0) + 4}" => 9., + if_complex_condition: "if(@sqrt(16) + @sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, + if_complex_blocks: "if(1, 2*@sqrt(16) + @sin(pi/2), 3*@cos(0) + 4)" => 9., // Edge cases - if_zero: "if(0.0){1} else {2}" => 2., - if_negative: "if(-1){1} else {2}" => 1., - if_infinity: "if(inf){1} else {2}" => 1., - + if_zero: "if(0.0, 1, 2)" => 2., // Complex nested expressions - if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1)){3 + 4 * 2} else {5 - 2 / 1}" => 11., + if_nested_expr: "if((@sqrt(16) + 2) * (@sin(pi) + 1), 3 + 4 * 2, 5 - 2 / 1)" => 11., } } diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 4385135340..0b3fe7a196 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -1,7 +1,9 @@ use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; use crate::context::EvalContext; +use crate::lexer::{Span, Token, TokenStream, lexer}; use crate::value::{Complex, Number, Value}; use chumsky::container::Seq; +use chumsky::input::{BorrowInput, ValueInput}; use chumsky::{Parser, prelude::*}; use lazy_static::lazy_static; use num_complex::ComplexFloat; @@ -10,95 +12,81 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum ParseError<'src> { - #[error("Syntax error(s): {0:#?}")] - Syntax(Vec>), + #[error("lexical error(s): {0:#?}")] + Lex(Vec>), + + #[error("syntax error(s): {0:#?}")] + Parse(Vec, Span>>), } impl Node { - pub fn try_parse_from_str(s: &str) -> Result { - let parsed = chumsky_parser().parse(s); - if parsed.has_output() { - Ok(parsed.into_output().unwrap()) - } else { - Err(ParseError::Syntax(parsed.into_errors())) + /// Lex + parse the source and either return an AST `Node` + /// or a typed `ParseError`. + pub fn try_parse_from_str(src: &str) -> Result { + // ── stage 1: lexing ────────────────────────────────────────────── + let (tokens_opt, lex_errs) = lexer().parse(src).into_output_errors(); + if !lex_errs.is_empty() { + return Err(ParseError::Lex(lex_errs)); + } + let tokens = TokenStream::new(tokens_opt.expect("lexer always returns tokens with recovery")); + + match parser().parse(tokens.map((0..src.len()).into(), |(t, s)| (t, s))).into_result() { + Ok(ast) => Ok(ast), + Err(errs) => Err(ParseError::Parse(errs)), } } } -pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Err>> { +pub fn parser<'src, I>() -> impl Parser<'src, I, Node, extra::Err, Span>>> +where + I: ValueInput<'src, Token = Token<'src>, Span = Span>, +{ recursive(|expr| { - let float = text::int(10) - .then(just('.').map(|c: char| c).then(text::int(10)).or_not()) - .then(just('e').or(just('E')).then(one_of("+-").or_not()).then(text::int(10)).or_not()) - .map(|((int_part, opt_frac), opt_exp): ((&str, _), _)| { - let mut s: String = int_part.to_string(); - if let Some((dot, frac)) = opt_frac { - s.push(dot); - s.push_str(frac); - } - if let Some(((e, sign), exp)) = opt_exp { - s.push(e); - if let Some(sign) = sign { - s.push(sign); - } - s.push_str(exp); - } - Node::Lit(Literal::Float(s.parse().unwrap())) - }); - - let constant = choice(( - just("pi").or(just("π")).map(|_| Node::Lit(Literal::Float(std::f64::consts::PI))), - just("tau").or(just("τ")).map(|_| Node::Lit(Literal::Float(std::f64::consts::TAU))), - just("e").map(|_| Node::Lit(Literal::Float(std::f64::consts::E))), - just("phi").or(just("φ")).map(|_| Node::Lit(Literal::Float(1.618_033_988_75))), - just("inf").or(just("∞")).map(|_| Node::Lit(Literal::Float(f64::INFINITY))), - just("i").map(|_| Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), // Assuming `Complex` impl - just("G").map(|_| Node::Lit(Literal::Float(9.80665))), // Standard gravity on Earth - )); - - let ident = text::ident().padded(); - - let var = ident.map(|s: &str| Node::Var(s.to_string())); - - let args = expr.clone().separated_by(just(',')).collect::>().delimited_by(just('('), just(')')); - - let call = ident.then(args).map(|(name, args): (&str, Vec)| Node::FnCall { name: name.to_string(), expr: args }); - - let parens = expr.clone().clone().delimited_by(just('('), just(')')); - - let conditional = just("if") - .padded() - .ignore_then(expr.clone().delimited_by(just('('), just(')'))) - .padded() - .then(expr.clone().delimited_by(just('{'), just('}'))) - .padded() - .then_ignore(just("else")) - .padded() - .then(expr.clone().delimited_by(just('{'), just('}'))) - .padded() - .map(|((cond, if_b), else_b): ((Node, _), _)| Node::Conditional { - condition: Box::new(cond), - if_block: Box::new(if_b), - else_block: Box::new(else_b), - }); - - let atom = choice((conditional, float, constant, call, parens, var)).boxed(); - - let add_op = choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded(); - let mul_op = choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded(); - let pow_op = just('^').to(BinaryOp::Pow).padded(); - let unary_op = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt))).padded(); + let constant = select! {Token::Const(x) => Node::Lit(x)}; + + let args = expr.clone().separated_by(just(Token::Comma)).collect::>().delimited_by(just(Token::LParen), just(Token::RParen)); + + let if_expr = just(Token::If) + .ignore_then(args.clone()) // Parses (cond, a, b) + .try_map(|args: Vec, span| { + if args.len() != 3 { + return Err(Rich::custom(span, "Expected 3 arguments in if(cond, a, b)")); + } + let mut iter = args.into_iter(); + let cond = iter.next().unwrap(); + let if_b = iter.next().unwrap(); + let else_b = iter.next().unwrap(); + Ok(Node::Conditional { + condition: Box::new(cond), + if_block: Box::new(if_b), + else_block: Box::new(else_b), + }) + }); + + let call = select! {Token::Call(s) => s} + .then(args) + .try_map(|(name, args): (&str, Vec), span| Ok(Node::FnCall { name: name.to_string(), expr: args })); + + let parens = expr.clone().clone().delimited_by(just(Token::LParen), just(Token::RParen)); + let var = select! { Token::Var(name) => Node::Var(name.to_string()) }; + + let atom = choice((constant, if_expr, call, parens, var)).boxed(); + + let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); + let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div))); + let pow_op = just(Token::Caret).to(BinaryOp::Pow); + let unary_op = just(Token::Minus).to(UnaryOp::Neg); let cmp_op = choice(( - just("<").to(BinaryOp::Lt), - just("<=").to(BinaryOp::Leq), - just(">").to(BinaryOp::Gt), - just(">=").to(BinaryOp::Geq), - just("==").to(BinaryOp::Eq), + just(Token::Lt).to(BinaryOp::Lt), + just(Token::Le).to(BinaryOp::Leq), + just(Token::Gt).to(BinaryOp::Gt), + just(Token::Ge).to(BinaryOp::Geq), + just(Token::EqEq).to(BinaryOp::Eq), )); let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }); - let cmp = unary.clone().foldl(cmp_op.padded().then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { + let cmp = unary.clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), @@ -119,13 +107,11 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er }) .boxed(); - let sum = product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { + product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), - }); - - sum.padded() + }) }) } @@ -147,7 +133,7 @@ mod tests { test_parser! { test_parse_int_literal: "42" => Node::Lit(Literal::Float(42.0)), test_parse_float_literal: "3.14" => Node::Lit(Literal::Float(#[allow(clippy::approx_constant)] 3.14)), - test_parse_ident: "x" => Node::Var("x".to_string()), + test_parse_ident: "#x" => Node::Var("x".to_string()), test_parse_unary_neg: "-42" => Node::UnaryOp { expr: Box::new(Node::Lit(Literal::Float(42.0))), op: UnaryOp::Neg, @@ -167,15 +153,14 @@ mod tests { op: BinaryOp::Pow, rhs: Box::new(Node::Lit(Literal::Float(3.0))), }, - test_parse_unary_sqrt: "sqrt(16)" => Node::UnaryOp { - expr: Box::new(Node::Lit(Literal::Float(16.0))), - op: UnaryOp::Sqrt, + test_parse_unary_sqrt: "@sqrt(16)" => Node::FnCall { + name: "sqrt".to_string(), + expr: vec![Node::Lit(Literal::Float(16.0))], }, - test_parse_sqr_ident: "sqr(16)" => Node::FnCall { - name:"sqr".to_string(), + test_parse_i_call: "@i(16)" => Node::FnCall { + name:"i".to_string(), expr: vec![Node::Lit(Literal::Float(16.0))] }, - test_parse_complex_expr: "(1 + 2) * 3 - 4 ^ 2" => Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::BinOp { @@ -193,7 +178,7 @@ mod tests { rhs: Box::new(Node::Lit(Literal::Float(2.0))), }), }, - test_conditional_expr: "if (x+3) {0} else {1}" => Node::Conditional{ + test_conditional_expr: "if (#x+3, 0, 1)" => Node::Conditional{ condition: Box::new(Node::BinOp{ lhs: Box::new(Node::Var("x".to_string())), op: BinaryOp::Add, From b38ffa9c20af9b47756e4a3abf4bd2b21eda46d0 Mon Sep 17 00:00:00 2001 From: urisinger Date: Mon, 9 Jun 2025 15:48:05 +0300 Subject: [PATCH 11/16] fix benches --- libraries/math-parser/benches/bench.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs index fc27144491..e2778b91c4 100644 --- a/libraries/math-parser/benches/bench.rs +++ b/libraries/math-parser/benches/bench.rs @@ -33,17 +33,17 @@ macro_rules! generate_benchmarks { } generate_benchmarks! { - "(3 * (4 + sqrt(25)) - cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations - "((5 + 2 * (3 - sqrt(49)))^2) / (1 + sqrt(16)) + tau / 2", // Complex nested expression with constants - "log(100, 10) + (5 * sin(pi/4) + sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions - "(sqrt(144) * 2 + 5) / (3 * (4 - sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations - "cos(2 * pi) + tan(pi / 3) * log(32, 2) - sqrt(256)", // Multiple trigonometric and logarithmic functions - "(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - sqrt(225) + phi", // Mixed arithmetic with constants - "(5^2 + 3^3) * (sqrt(81) + sqrt(64)) - tau * log(1000, 10)", // Power and square root with constants - "((8 * sqrt(49) - 2 * e) + log(256, 2) / (2 + cos(pi))) * 1.5", // Nested functions and constants - "(tan(pi / 4) + 5) * (3 + sqrt(36)) / (log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm - "((3 * e + 2 * sqrt(100)) - cos(tau / 4)) * log(27, 3) + phi", // Mixed constant usage and functions - "(sqrt(100) + 5 * sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation - "((sin(pi/2) + cos(0)) * (e^2 - 2 * sqrt(16))) / (log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions - "(5 * (7 + sqrt(121)) - (log(243, 3) * phi)) + 3^5 / tau", // + "(3 * (4 + @sqrt(25)) - @cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations + "((5 + 2 * (3 - @sqrt(49)))^2) / (1 + @sqrt(16)) + tau / 2", // Complex nested expression with constants + "@log(100, 10) + (5 * @sin(pi/4) + @sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions + "(@sqrt(144) * 2 + 5) / (3 * (4 - @sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations + "@cos(2 * pi) + @tan(pi / 3) * @log(32, 2) - @sqrt(256)", // Multiple trigonometric and logarithmic functions + "(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - @sqrt(225) + phi", // Mixed arithmetic with constants + "(5^2 + 3^3) * (@sqrt(81) + @sqrt(64)) - tau * @log(1000, 10)", // Power and square root with constants + "((8 * @sqrt(49) - 2 * e) + @log(256, 2) / (2 + @cos(pi))) * 1.5", // Nested functions and constants + "(tan(pi / 4) + 5) * (3 + @sqrt(36)) / (@log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm + "((3 * e + 2 * @sqrt(100)) - @cos(tau / 4)) * @log(27, 3) + phi", // Mixed constant usage and functions + "(@sqrt(100) + 5 * @sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation + "((@sin(pi/2) + @cos(0)) * (e^2 - 2 * @sqrt(16))) / (@log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions + "(5 * (7 + @sqrt(121)) - (@log(243, 3) * phi)) + 3^5 / tau", // } From 6cc07454d135754ebc8cd4eefbc1c37fb0c4570c Mon Sep 17 00:00:00 2001 From: urisinger Date: Sun, 29 Jun 2025 15:46:10 +0300 Subject: [PATCH 12/16] add implicit multiplication and make lexer much faster --- libraries/math-parser/benches/bench.rs | 26 +-- libraries/math-parser/src/lexer.rs | 272 +++++++++++++++++-------- libraries/math-parser/src/lib.rs | 44 ++-- libraries/math-parser/src/parser.rs | 45 ++-- 4 files changed, 242 insertions(+), 145 deletions(-) diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs index e2778b91c4..fc27144491 100644 --- a/libraries/math-parser/benches/bench.rs +++ b/libraries/math-parser/benches/bench.rs @@ -33,17 +33,17 @@ macro_rules! generate_benchmarks { } generate_benchmarks! { - "(3 * (4 + @sqrt(25)) - @cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations - "((5 + 2 * (3 - @sqrt(49)))^2) / (1 + @sqrt(16)) + tau / 2", // Complex nested expression with constants - "@log(100, 10) + (5 * @sin(pi/4) + @sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions - "(@sqrt(144) * 2 + 5) / (3 * (4 - @sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations - "@cos(2 * pi) + @tan(pi / 3) * @log(32, 2) - @sqrt(256)", // Multiple trigonometric and logarithmic functions - "(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - @sqrt(225) + phi", // Mixed arithmetic with constants - "(5^2 + 3^3) * (@sqrt(81) + @sqrt(64)) - tau * @log(1000, 10)", // Power and square root with constants - "((8 * @sqrt(49) - 2 * e) + @log(256, 2) / (2 + @cos(pi))) * 1.5", // Nested functions and constants - "(tan(pi / 4) + 5) * (3 + @sqrt(36)) / (@log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm - "((3 * e + 2 * @sqrt(100)) - @cos(tau / 4)) * @log(27, 3) + phi", // Mixed constant usage and functions - "(@sqrt(100) + 5 * @sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation - "((@sin(pi/2) + @cos(0)) * (e^2 - 2 * @sqrt(16))) / (@log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions - "(5 * (7 + @sqrt(121)) - (@log(243, 3) * phi)) + 3^5 / tau", // + "(3 * (4 + sqrt(25)) - cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations + "((5 + 2 * (3 - sqrt(49)))^2) / (1 + sqrt(16)) + tau / 2", // Complex nested expression with constants + "log(100, 10) + (5 * sin(pi/4) + sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions + "(sqrt(144) * 2 + 5) / (3 * (4 - sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations + "cos(2 * pi) + tan(pi / 3) * log(32, 2) - sqrt(256)", // Multiple trigonometric and logarithmic functions + "(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - sqrt(225) + phi", // Mixed arithmetic with constants + "(5^2 + 3^3) * (sqrt(81) + sqrt(64)) - tau * log(1000, 10)", // Power and square root with constants + "((8 * sqrt(49) - 2 * e) + log(256, 2) / (2 + cos(pi))) * 1.5", // Nested functions and constants + "(tan(pi / 4) + 5) * (3 + sqrt(36)) / (log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm + "((3 * e + 2 * sqrt(100)) - cos(tau / 4)) * log(27, 3) + phi", // Mixed constant usage and functions + "(sqrt(100) + 5 * sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation + "((sin(pi/2) + cos(0)) * (e^2 - 2 * sqrt(16))) / (log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions + "(5 * (7 + sqrt(121)) - (log(243, 3) * phi)) + 3^5 / tau", // } diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs index a6a1dd9e67..9bce2a5656 100644 --- a/libraries/math-parser/src/lexer.rs +++ b/libraries/math-parser/src/lexer.rs @@ -4,8 +4,11 @@ use chumsky::input::{Input, ValueInput}; use chumsky::prelude::*; use chumsky::span::SimpleSpan; use chumsky::text::{ident, int}; +use core::f64; use num_complex::Complex64; +use std::iter::Peekable; use std::ops::Range; +use std::str::Chars; pub type Span = SimpleSpan; @@ -13,8 +16,7 @@ pub type Span = SimpleSpan; pub enum Token<'src> { // literals ---------------------------------------------------------------- Const(Literal), // numeric or complex constants recognised at lex‑time - Var(&'src str), // #identifier (variables) - Call(&'src str), + Ident(&'src str), // punctuation ------------------------------------------------------------- LParen, RParen, @@ -34,124 +36,216 @@ pub enum Token<'src> { If, } -pub fn lexer<'src>() -> impl Parser<'src, &'src str, Vec<(Token<'src>, Span)>, extra::Err>> { - // ── numbers ──────────────────────────────────────────────────────────── - let num = int(10) - .then(just('.').then(int(10)).or_not()) - .then(just('e').or(just('E')).then(one_of("+-").or_not()).then(int(10)).or_not()) - .map(|((int_part, frac), exp): ((&str, _), _)| { - let mut s = int_part.to_string(); - if let Some((_, frac)) = frac { - s.push('.'); - s.push_str(frac); +fn const_lit(name: &str) -> Option { + use std::f64::consts::*; + + Some(match name { + "pi" | "π" => Literal::Float(PI), + "tau" | "τ" => Literal::Float(TAU), + "e" => Literal::Float(E), + "phi" | "φ" => Literal::Float(1.618_033_988_75), + "inf" | "∞" => Literal::Float(f64::INFINITY), + "i" => Literal::Complex(Complex64::new(0.0, 1.0)), + "G" => Literal::Float(9.80665), + _ => return None, + }) +} + +pub struct Lexer<'a> { + input: &'a str, + pos: usize, +} + +impl<'a> Lexer<'a> { + pub fn new(input: &'a str) -> Self { + Self { input, pos: 0 } + } + + fn peek(&self) -> Option { + self.input[self.pos..].chars().next() + } + + fn bump(&mut self) -> Option { + let c = self.peek()?; + self.pos += c.len_utf8(); + Some(c) + } + + fn consume_while(&mut self, cond: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = self.pos; + while self.peek().is_some_and(&cond) { + self.bump(); + } + &self.input[start..self.pos] + } + + fn lex_ident(&mut self) -> &'a str { + self.consume_while(|c| c.is_alphanumeric() || c == '_') + } + + fn lex_uint(&mut self) -> Option<(u64, usize)> { + let mut v = 0u64; + let mut digits = 0; + while let Some(d) = self.peek().and_then(|c| c.to_digit(10)) { + v = v * 10 + d as u64; + digits += 1; + self.bump(); + } + (digits > 0).then_some((v, digits)) + } + + fn lex_number(&mut self) -> Option { + let start_pos = self.pos; + let (int_val, int_digits) = self.lex_uint().unwrap_or((0, 0)); + let mut got_digit = int_digits > 0; + let mut num = int_val as f64; + + if self.peek() == Some('.') { + self.bump(); + if let Some((frac_val, frac_digits)) = self.lex_uint() { + num += (frac_val as f64) / 10f64.powi(frac_digits as i32); + got_digit = true; } - if let Some(((e, sign), exp)) = exp { - s.push(e); - if let Some(sign) = sign { - s.push(sign); + } + + if matches!(self.peek(), Some('e' | 'E')) { + self.bump(); + let sign = match self.peek() { + Some('+') => { + self.bump(); + 1 } - s.push_str(exp); + Some('-') => { + self.bump(); + -1 + } + _ => 1, + }; + if let Some((exp_val, _)) = self.lex_uint() { + num *= 10f64.powi(sign * exp_val as i32); + } else { + self.pos = start_pos; + return None; } - Token::Const(Literal::Float(s.parse::().unwrap())) - }); - - // ── single‑char symbols ──────────────────────────────────────────────── - let sym = choice(( - just('(').to(Token::LParen), - just(')').to(Token::RParen), - just(',').to(Token::Comma), - just('+').to(Token::Plus), - just('-').to(Token::Minus), - just('*').to(Token::Star), - just('/').to(Token::Slash), - just('^').to(Token::Caret), - )); - - // ── comparison operators ─────────────────────────────────────────────── - let cmp = choice(( - just("<=").to(Token::Le), - just(">=").to(Token::Ge), - just("==").to(Token::EqEq), - just('<').to(Token::Lt), - just('>').to(Token::Gt), - )); - - let kw_token = |w, t| just(w).padded().to(t); - - let kw_lit = |w, lit: Literal| just(w).padded().to(lit); - - let const_token = choice(( - kw_lit("pi", Literal::Float(std::f64::consts::PI)), - kw_lit("π", Literal::Float(std::f64::consts::PI)), - kw_lit("tau", Literal::Float(std::f64::consts::TAU)), - kw_lit("τ", Literal::Float(std::f64::consts::TAU)), - kw_lit("e", Literal::Float(std::f64::consts::E)), - kw_lit("phi", Literal::Float(1.618_033_988_75)), - kw_lit("φ", Literal::Float(1.618_033_988_75)), - kw_lit("inf", Literal::Float(f64::INFINITY)), - kw_lit("∞", Literal::Float(f64::INFINITY)), - kw_lit("i", Literal::Complex(Complex64::new(0.0, 1.0))), - kw_lit("G", Literal::Float(9.80665)), - )) - .map(Token::Const); - - let var_token = just('#').ignore_then(ident()).map(Token::Var); - let call_token = just('@').ignore_then(ident()).map(Token::Call); - - choice((num, kw_token("if", Token::If), const_token, cmp, sym, var_token, call_token)) - .map_with(|t, e| (t, e.span())) - .padded() - .repeated() - .collect() -} + } -#[derive(Debug)] -pub struct TokenStream<'src> { - tokens: Vec<(Token<'src>, Span)>, -} + got_digit.then_some(num) + } + + fn skip_ws(&mut self) { + self.consume_while(char::is_whitespace); + } + + pub fn next_token(&mut self) -> Option> { + self.skip_ws(); + let start = self.pos; + let ch = self.bump()?; + + use Token::*; + let tok = match ch { + '(' => LParen, + ')' => RParen, + ',' => Comma, + '+' => Plus, + '-' => Minus, + '*' => Star, + '/' => Slash, + '^' => Caret, + + '<' => { + if self.peek() == Some('=') { + self.bump(); + Le + } else { + Lt + } + } + '>' => { + if self.peek() == Some('=') { + self.bump(); + Ge + } else { + Gt + } + } + '=' => { + if self.peek() == Some('=') { + self.bump(); + EqEq + } else { + return None; + } + } + + c if c.is_ascii_digit() || (c == '.' && self.peek().is_some_and(|c| c.is_ascii_digit())) => { + self.pos = start; + Const(Literal::Float(self.lex_number()?)) + } + + _ => { + self.consume_while(|c| c.is_alphanumeric() || c == '_'); + let ident = &self.input[start..self.pos]; + + if ident == "if" { + If + } else if let Some(lit) = const_lit(ident) { + Const(lit) + } else if ch.is_alphanumeric() { + Ident(ident) + } else { + return None; + } + } + }; -impl<'src> TokenStream<'src> { - pub fn new(tokens: Vec<(Token<'src>, Span)>) -> Self { - TokenStream { tokens } + Some(tok) } } -impl<'src> Input<'src> for TokenStream<'src> { - type Token = (Token<'src>, Span); +impl<'src> Input<'src> for Lexer<'src> { + type Token = Token<'src>; type Span = Span; - type Cursor = usize; - type MaybeToken = (Token<'src>, Span); + type Cursor = usize; // byte offset inside `input` + type MaybeToken = Token<'src>; type Cache = Self; + #[inline] fn begin(self) -> (Self::Cursor, Self::Cache) { (0, self) } + #[inline] fn cursor_location(cursor: &Self::Cursor) -> usize { *cursor } - #[inline(always)] + #[inline] unsafe fn next_maybe(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { - if let Some(tok) = this.tokens.get(*cursor) { - *cursor += 1; - Some(tok.clone()) + this.pos = *cursor; + if let Some(tok) = this.next_token() { + *cursor = this.pos; + Some(tok) } else { None } } - #[inline(always)] + #[inline] unsafe fn span(_this: &mut Self::Cache, range: Range<&Self::Cursor>) -> Self::Span { (*range.start..*range.end).into() } } -impl<'src> ValueInput<'src> for TokenStream<'src> { +impl<'src> ValueInput<'src> for Lexer<'src> { + #[inline] unsafe fn next(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { - if let Some(tok) = this.tokens.get(*cursor) { - *cursor += 1; - Some(tok.clone()) + this.pos = *cursor; + if let Some(tok) = this.next_token() { + *cursor = this.pos; + Some(tok) } else { None } diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index 47f37f1e64..e360ef0cce 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -14,7 +14,7 @@ use executer::EvalError; use parser::ParseError; use value::Value; -pub fn evaluate(expression: &str) -> Result, ParseError> { +pub fn evaluate(expression: &str) -> Result, ParseError<'_>> { let expr = ast::Node::try_parse_from_str(expression); let context = EvalContext::default(); expr.map(|node| node.eval(&context)) @@ -96,8 +96,8 @@ mod tests { order_of_operations_negative_prefix: "-10 + 5" => -5., order_of_operations_add_multiply: "5+1*1+5" => 11., order_of_operations_add_negative_multiply: "5+(-1)*1+5" => 9., - order_of_operations_sqrt: "@sqrt(25) + 11" => 16., - order_of_operations_sqrt_expression: "@sqrt(25+11)" => 6., + order_of_operations_sqrt: "sqrt(25) + 11" => 16., + order_of_operations_sqrt_expression: "sqrt(25+11)" => 6., // Parentheses and nested expressions parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => 64., @@ -105,13 +105,13 @@ mod tests { parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => 16., // Square root and nested square root - sqrt_chain_operations: "@sqrt(16) + @sqrt(9) * @sqrt(4)" => 10., - sqrt_nested: "@sqrt(@sqrt(81))" => 3., - sqrt_divide_expression: "@sqrt((25 + 11) / 9)" => 2., + sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => 10., + sqrt_nested: "sqrt(sqrt(81))" => 3., + sqrt_divide_expression: "sqrt((25 + 11) / 9)" => 2., // Mixed square root and units - sqrt_add_multiply: "@sqrt(49) - 1 + 2 * 3" => 12., - sqrt_addition_multiply: "(@sqrt(36) + 2) * 2" => 16., + sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => 12., + sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => 16., // Exponentiation exponent_single: "2^3" => 8., @@ -120,10 +120,10 @@ mod tests { // Operations with negative values negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => -8., - negative_sqrt_addition: "-(@sqrt(16) + @sqrt(9))" => -7., - multiply_sqrt_subtract: "5 * 2 + @sqrt(16) / 2 - 3" => 9., - add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - @sqrt(25)" => 8., - add_sqrt_subtract_nested_multiply: "10 + @sqrt(64) - (5 * (2 + 1))" => 3., + negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => -7., + multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => 9., + add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => 8., + add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => 3., // Mathematical constants constant_pi: "pi" => std::f64::consts::PI, @@ -139,11 +139,11 @@ mod tests { infinity_subtract_large_number: "inf - 1000" => f64::INFINITY, // Trigonometric functions - trig_sin_pi: "@sin(pi)" => 0.0, - trig_cos_zero: "@cos(0)" => 1.0, - trig_tan_pi_div_four: "@tan(pi/4)" => 1.0, - trig_sin_tau: "@sin(tau)" => 0.0, - trig_cos_tau_div_two: "@cos(tau/2)" => -1.0, + trig_sin_pi: "sin(pi)" => 0.0, + trig_cos_zero: "cos(0)" => 1.0, + trig_tan_pi_div_four: "tan(pi/4)" => 1.0, + trig_sin_tau: "sin(tau)" => 0.0, + trig_cos_tau_div_two: "cos(tau/2)" => -1.0, // Basic if statements if_true_condition: "if(1,5,3)" => 5., @@ -162,21 +162,21 @@ mod tests { if_with_e: "if(e < 3, 1, 0)" => 1., // Functions in conditions - if_with_sqrt: "if(@sqrt(16) == 4, 1, 0)" => 1., - if_with_sin: "if(@sin(pi) == 0.0, 1, 0)" => 0., + if_with_sqrt: "if(sqrt(16) == 4, 1, 0)" => 1., + if_with_sin: "if(sin(pi) == 0.0, 1, 0)" => 0., // Nested if statements nested_if: "if(1, if(0, 1, 2), 3)" => 2., nested_if_complex: "if(2-2 == 0, if(1, 5, 6), if(1, 7, 8))" => 5., // Mixed operations in conditions and blocks - if_complex_condition: "if(@sqrt(16) + @sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, - if_complex_blocks: "if(1, 2*@sqrt(16) + @sin(pi/2), 3*@cos(0) + 4)" => 9., + if_complex_condition: "if(sqrt(16) + sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, + if_complex_blocks: "if(1, 2*sqrt(16) + sin(pi/2), 3*cos(0) + 4)" => 9., // Edge cases if_zero: "if(0.0, 1, 2)" => 2., // Complex nested expressions - if_nested_expr: "if((@sqrt(16) + 2) * (@sin(pi) + 1), 3 + 4 * 2, 5 - 2 / 1)" => 11., + if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1), 3 + 4 * 2, 5 - 2 / 1)" => 11., } } diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 0b3fe7a196..225d0d7931 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -1,6 +1,6 @@ use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; use crate::context::EvalContext; -use crate::lexer::{Span, Token, TokenStream, lexer}; +use crate::lexer::{Lexer, Span, Token}; use crate::value::{Complex, Number, Value}; use chumsky::container::Seq; use chumsky::input::{BorrowInput, ValueInput}; @@ -12,9 +12,6 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum ParseError<'src> { - #[error("lexical error(s): {0:#?}")] - Lex(Vec>), - #[error("syntax error(s): {0:#?}")] Parse(Vec, Span>>), } @@ -22,15 +19,10 @@ pub enum ParseError<'src> { impl Node { /// Lex + parse the source and either return an AST `Node` /// or a typed `ParseError`. - pub fn try_parse_from_str(src: &str) -> Result { - // ── stage 1: lexing ────────────────────────────────────────────── - let (tokens_opt, lex_errs) = lexer().parse(src).into_output_errors(); - if !lex_errs.is_empty() { - return Err(ParseError::Lex(lex_errs)); - } - let tokens = TokenStream::new(tokens_opt.expect("lexer always returns tokens with recovery")); + pub fn try_parse_from_str(src: &str) -> Result> { + let tokens = Lexer::new(src); - match parser().parse(tokens.map((0..src.len()).into(), |(t, s)| (t, s))).into_result() { + match parser().parse(tokens).into_result() { Ok(ast) => Ok(ast), Err(errs) => Err(ParseError::Parse(errs)), } @@ -63,12 +55,12 @@ where }) }); - let call = select! {Token::Call(s) => s} + let call = select! {Token::Ident(s) => s} .then(args) .try_map(|(name, args): (&str, Vec), span| Ok(Node::FnCall { name: name.to_string(), expr: args })); let parens = expr.clone().clone().delimited_by(just(Token::LParen), just(Token::RParen)); - let var = select! { Token::Var(name) => Node::Var(name.to_string()) }; + let var = select! { Token::Ident(name) => Node::Var(name.to_string()) }; let atom = choice((constant, if_expr, call, parens, var)).boxed(); @@ -84,7 +76,7 @@ where just(Token::EqEq).to(BinaryOp::Eq), )); - let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }); + let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed(); let cmp = unary.clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), @@ -107,7 +99,13 @@ where }) .boxed(); - product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { + let add = product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + add.clone().foldl(add.map(|rhs| (BinaryOp::Mul, rhs)).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), @@ -133,7 +131,7 @@ mod tests { test_parser! { test_parse_int_literal: "42" => Node::Lit(Literal::Float(42.0)), test_parse_float_literal: "3.14" => Node::Lit(Literal::Float(#[allow(clippy::approx_constant)] 3.14)), - test_parse_ident: "#x" => Node::Var("x".to_string()), + test_parse_ident: "x" => Node::Var("x".to_string()), test_parse_unary_neg: "-42" => Node::UnaryOp { expr: Box::new(Node::Lit(Literal::Float(42.0))), op: UnaryOp::Neg, @@ -153,14 +151,19 @@ mod tests { op: BinaryOp::Pow, rhs: Box::new(Node::Lit(Literal::Float(3.0))), }, - test_parse_unary_sqrt: "@sqrt(16)" => Node::FnCall { + test_parse_unary_sqrt: "sqrt(16)" => Node::FnCall { name: "sqrt".to_string(), expr: vec![Node::Lit(Literal::Float(16.0))], }, - test_parse_i_call: "@i(16)" => Node::FnCall { - name:"i".to_string(), + test_parse_ii_call: "ii(16)" => Node::FnCall { + name:"ii".to_string(), expr: vec![Node::Lit(Literal::Float(16.0))] }, + test_parse_i_mul: "i(16)" => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), + op: BinaryOp::Mul, + rhs: Box::new(Node::Lit(Literal::Float(16.0))), + }, test_parse_complex_expr: "(1 + 2) * 3 - 4 ^ 2" => Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::BinOp { @@ -178,7 +181,7 @@ mod tests { rhs: Box::new(Node::Lit(Literal::Float(2.0))), }), }, - test_conditional_expr: "if (#x+3, 0, 1)" => Node::Conditional{ + test_conditional_expr: "if (x+3, 0, 1)" => Node::Conditional{ condition: Box::new(Node::BinOp{ lhs: Box::new(Node::Var("x".to_string())), op: BinaryOp::Add, From e0563239b18bff489cfa7ad9528bdd4116b0c865 Mon Sep 17 00:00:00 2001 From: urisinger Date: Tue, 15 Jul 2025 13:19:19 +0300 Subject: [PATCH 13/16] fixed parser, add debugging for errors --- Cargo.lock | 14 ++- libraries/math-parser/Cargo.toml | 1 + libraries/math-parser/benches/bench.rs | 10 +- libraries/math-parser/src/diagnostic.rs | 159 ++++++++++++++++++++++++ libraries/math-parser/src/lexer.rs | 118 ++++++++++++++---- libraries/math-parser/src/lib.rs | 123 ++++++++++-------- libraries/math-parser/src/parser.rs | 92 ++++++++------ node-graph/gcore/src/ops.rs | 6 +- 8 files changed, 404 insertions(+), 119 deletions(-) create mode 100644 libraries/math-parser/src/diagnostic.rs diff --git a/Cargo.lock b/Cargo.lock index 463dd7e585..16845e0491 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -833,6 +833,17 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "codespan-reporting" +version = "0.12.0" +source = "git+https://github.com/urisinger/codespan-style-writer.git#164e6499a56ab10f5358ae1855be8266e210502c" +dependencies = [ + "lazy_static", + "serde", + "termcolor", + "unicode-width", +] + [[package]] name = "color" version = "0.1.0" @@ -3361,6 +3372,7 @@ name = "math-parser" version = "0.0.0" dependencies = [ "chumsky", + "codespan-reporting 0.12.0", "criterion", "lazy_static", "num-complex", @@ -3490,7 +3502,7 @@ dependencies = [ "bit-set", "bitflags 2.9.0", "cfg_aliases 0.1.1", - "codespan-reporting", + "codespan-reporting 0.11.1", "hexf-parse", "indexmap 2.7.1", "log", diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml index efedc5ede0..1a25218792 100644 --- a/libraries/math-parser/Cargo.toml +++ b/libraries/math-parser/Cargo.toml @@ -12,6 +12,7 @@ thiserror = "2.0" lazy_static = "1.5" num-complex = "0.4" chumsky = { version = "0.10", default-features = false, features = ["std"] } +codespan-reporting = { git = "https://github.com/urisinger/codespan-style-writer.git"} [dev-dependencies] criterion = "0.5" diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs index fc27144491..ac734b7f14 100644 --- a/libraries/math-parser/benches/bench.rs +++ b/libraries/math-parser/benches/bench.rs @@ -8,7 +8,7 @@ macro_rules! generate_benchmarks { $( c.bench_function(concat!("parse ", $input), |b| { b.iter(|| { - let _ = black_box(ast::Node::try_parse_from_str($input)).unwrap(); + let _ = black_box(ast::Node::try_parse_from_str($input)); }); }); )* @@ -16,7 +16,13 @@ macro_rules! generate_benchmarks { fn evaluation_bench(c: &mut Criterion) { $( - let expr = ast::Node::try_parse_from_str($input).unwrap(); + let expr = match ast::Node::try_parse_from_str($input) { + Ok(expr) => expr, + Err(err) => { + err.print(); + panic!(concat!("failed to parse `", $input, "`")); + } + }; let context = EvalContext::default(); c.bench_function(concat!("eval ", $input), |b| { diff --git a/libraries/math-parser/src/diagnostic.rs b/libraries/math-parser/src/diagnostic.rs new file mode 100644 index 0000000000..2d7d4874d6 --- /dev/null +++ b/libraries/math-parser/src/diagnostic.rs @@ -0,0 +1,159 @@ +use std::io::{self, Write}; + +use codespan_reporting::{ + diagnostic::{Diagnostic, Label, LabelStyle, Severity}, + files::{self, Files, SimpleFile, SimpleFiles}, + term::{ + self, Config, Renderer, RichDiagnostic, emit, + termcolor::{ColorChoice, StandardStream}, + }, +}; + +use crate::lexer::Span; + +pub struct CompileError { + pub file: SimpleFile, + pub diagnostics: Vec>, +} + +impl CompileError { + pub fn print(&self) { + let mut writer = StandardStream::stderr(ColorChoice::Auto); + let config = term::Config::default(); + for diag in &self.diagnostics { + term::emit(&mut writer.lock(), &config, &self.file, diag).unwrap(); + } + writer.flush(); + } + + pub fn render_html(&self, config: &Config) -> Result, files::Error> { + let mut buf = Vec::new(); + { + let mut html_writer = HtmlWriter::new(&mut buf); + + let mut renderer = Renderer::new(&mut html_writer, config); + for diag in &self.diagnostics { + RichDiagnostic::new(diag, config).render(&self.file, &mut renderer)?; + } + html_writer.close_span().expect("buffer writer cant fail"); + } + + Ok(buf) + } +} + +pub(crate) fn make_compile_error(filename: impl Into, src: &str, errs: impl IntoIterator)>) -> CompileError { + let file = SimpleFile::new(filename.into(), src.to_string()); + + let diagnostics = errs.into_iter().map(|(msg, primary, secondaries)| make_diagnostic(msg, primary, &secondaries)).collect(); + + CompileError { file, diagnostics } +} + +fn make_diagnostic(msg: impl Into, primary: Span, secondaries: &[(String, Span)]) -> Diagnostic<()> { + let msg_str = msg.into(); + let mut labels = vec![Label::primary((), primary).with_message(msg_str.clone())]; + for (smsg, span) in secondaries { + labels.push(Label::secondary((), *span).with_message(smsg.clone())); + } + Diagnostic::error().with_message(msg_str).with_labels(labels) +} + +struct HtmlWriter { + upstream: W, + span_open: bool, +} + +impl HtmlWriter { + pub fn new(upstream: W) -> Self { + HtmlWriter { upstream, span_open: false } + } + + /// Close any open span + fn close_span(&mut self) -> io::Result<()> { + if self.span_open { + write!(self.upstream, "")?; + self.span_open = false; + } + Ok(()) + } + + /// Open a new span with the given CSS class + fn open_span(&mut self, class: &str) -> io::Result<()> { + // close existing first + self.close_span()?; + write!(self.upstream, "", class)?; + self.span_open = true; + Ok(()) + } +} + +impl Write for HtmlWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut last = 0; + for (i, &b) in buf.iter().enumerate() { + let escape = match b { + b'<' => b"<"[..].as_ref(), + b'>' => b">"[..].as_ref(), + b'&' => b"&"[..].as_ref(), + _ => continue, + }; + self.upstream.write_all(&buf[last..i])?; + self.upstream.write_all(escape)?; + last = i + 1; + } + self.upstream.write_all(&buf[last..])?; + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + self.upstream.flush() + } +} + +impl codespan_reporting::term::WriteStyle for HtmlWriter { + fn set_header(&mut self, severity: Severity) -> io::Result<()> { + let class = match severity { + Severity::Bug => "header-bug", + Severity::Error => "header-error", + Severity::Warning => "header-warning", + Severity::Note => "header-note", + Severity::Help => "header-help", + }; + self.open_span(class) + } + + fn set_header_message(&mut self) -> io::Result<()> { + self.open_span("header-message") + } + + fn set_line_number(&mut self) -> io::Result<()> { + self.open_span("line-number") + } + + fn set_note_bullet(&mut self) -> io::Result<()> { + self.open_span("note-bullet") + } + + fn set_source_border(&mut self) -> io::Result<()> { + self.open_span("source-border") + } + + fn set_label(&mut self, severity: Severity, label_style: LabelStyle) -> io::Result<()> { + let sev = match severity { + Severity::Bug => "bug", + Severity::Error => "error", + Severity::Warning => "warning", + Severity::Note => "note", + Severity::Help => "help", + }; + let typ = match label_style { + LabelStyle::Primary => "primary", + LabelStyle::Secondary => "secondary", + }; + self.open_span(&format!("label-{}-{}", typ, sev)) + } + + fn reset(&mut self) -> io::Result<()> { + self.close_span() + } +} diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs index 9bce2a5656..4b49a08515 100644 --- a/libraries/math-parser/src/lexer.rs +++ b/libraries/math-parser/src/lexer.rs @@ -1,4 +1,3 @@ -// ── lexer.rs ─────────────────────────────────────────────────────────── use crate::ast::Literal; use chumsky::input::{Input, ValueInput}; use chumsky::prelude::*; @@ -6,6 +5,7 @@ use chumsky::span::SimpleSpan; use chumsky::text::{ident, int}; use core::f64; use num_complex::Complex64; +use std::fmt; use std::iter::Peekable; use std::ops::Range; use std::str::Chars; @@ -14,10 +14,10 @@ pub type Span = SimpleSpan; #[derive(Clone, Debug, PartialEq)] pub enum Token<'src> { - // literals ---------------------------------------------------------------- - Const(Literal), // numeric or complex constants recognised at lex‑time + Float(f64), + Const(Constant), Ident(&'src str), - // punctuation ------------------------------------------------------------- + LParen, RParen, Comma, @@ -26,29 +26,97 @@ pub enum Token<'src> { Star, Slash, Caret, - // comparison -------------------------------------------------------------- + Lt, Le, Gt, Ge, EqEq, - // keywords ---------------------------------------------------------------- + If, } -fn const_lit(name: &str) -> Option { - use std::f64::consts::*; - - Some(match name { - "pi" | "π" => Literal::Float(PI), - "tau" | "τ" => Literal::Float(TAU), - "e" => Literal::Float(E), - "phi" | "φ" => Literal::Float(1.618_033_988_75), - "inf" | "∞" => Literal::Float(f64::INFINITY), - "i" => Literal::Complex(Complex64::new(0.0, 1.0)), - "G" => Literal::Float(9.80665), - _ => return None, - }) +impl<'src> fmt::Display for Token<'src> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Token::Float(x) => write!(f, "{x}"), + Token::Const(c) => write!(f, "{c}"), + Token::Ident(name) => write!(f, "{name}"), + + Token::LParen => f.write_str("("), + Token::RParen => f.write_str(")"), + Token::Comma => f.write_str(","), + Token::Plus => f.write_str("+"), + Token::Minus => f.write_str("-"), + Token::Star => f.write_str("*"), + Token::Slash => f.write_str("/"), + Token::Caret => f.write_str("^"), + + Token::Lt => f.write_str("<"), + Token::Le => f.write_str("<="), + Token::Gt => f.write_str(">"), + Token::Ge => f.write_str(">="), + Token::EqEq => f.write_str("=="), + + Token::If => f.write_str("if"), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Constant { + Pi, + Tau, + E, + Phi, + Inf, + I, + G, +} + +impl Constant { + pub fn value(self) -> Literal { + use Constant::*; + use std::f64::consts; + match self { + Pi => Literal::Float(consts::PI), + Tau => Literal::Float(consts::TAU), + E => Literal::Float(consts::E), + Phi => Literal::Float(1.618_033_988_75), + Inf => Literal::Float(f64::INFINITY), + I => Literal::Complex(Complex64::new(0.0, 1.0)), + G => Literal::Float(9.80665), + } + } + + pub fn from_str(name: &str) -> Option { + use Constant::*; + Some(match name { + "pi" | "π" => Pi, + "tau" | "τ" => Tau, + "e" => E, + "phi" | "φ" => Phi, + "inf" | "∞" => Inf, + "i" => I, + "G" => G, + _ => return None, + }) + } +} + +impl fmt::Display for Constant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Constant::*; + f.write_str(match self { + Pi => "pi", + Tau => "tau", + E => "e", + Phi => "phi", + Inf => "inf", + I => "i", + G => "G", + }) + } } pub struct Lexer<'a> { @@ -182,7 +250,7 @@ impl<'a> Lexer<'a> { c if c.is_ascii_digit() || (c == '.' && self.peek().is_some_and(|c| c.is_ascii_digit())) => { self.pos = start; - Const(Literal::Float(self.lex_number()?)) + Float(self.lex_number()?) } _ => { @@ -191,7 +259,7 @@ impl<'a> Lexer<'a> { if ident == "if" { If - } else if let Some(lit) = const_lit(ident) { + } else if let Some(lit) = Constant::from_str(ident) { Const(lit) } else if ch.is_alphanumeric() { Ident(ident) @@ -205,6 +273,14 @@ impl<'a> Lexer<'a> { } } +impl<'a> Iterator for Lexer<'a> { + type Item = Token<'a>; + + fn next(&mut self) -> Option { + self.next_token() + } +} + impl<'src> Input<'src> for Lexer<'src> { type Token = Token<'src>; type Span = Span; diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index e360ef0cce..6c37ff78fb 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -3,6 +3,7 @@ pub mod ast; mod constants; pub mod context; +pub mod diagnostic; pub mod executer; pub mod lexer; pub mod parser; @@ -10,11 +11,11 @@ pub mod value; use ast::Unit; use context::{EvalContext, ValueMap}; +use diagnostic::CompileError; use executer::EvalError; -use parser::ParseError; use value::Value; -pub fn evaluate(expression: &str) -> Result, ParseError<'_>> { +pub fn evaluate(expression: &str) -> Result, CompileError> { let expr = ast::Node::try_parse_from_str(expression); let context = EvalContext::default(); expr.map(|node| node.eval(&context)) @@ -24,67 +25,79 @@ pub fn evaluate(expression: &str) -> Result, ParseError mod tests { use super::*; use ast::Unit; + use codespan_reporting::term::{ + self, + termcolor::{ColorChoice, StandardStream}, + }; use value::Number; const EPSILON: f64 = 1e-10_f64; - macro_rules! test_end_to_end{ - ($($name:ident: $input:expr_2021 => $expected_value:expr_2021),* $(,)?) => { - $( - #[test] - fn $name() { - let expected_value = $expected_value; - - let expr = ast::Node::try_parse_from_str($input); - let context = EvalContext::default(); - - dbg!(&expr); - - let actual_value = expr.map(|node| node.eval(&context)).unwrap(); - let actual_value = actual_value.unwrap(); - - - - let expected_value = expected_value.into(); - - match (actual_value, expected_value) { - (Value::Number(Number::Complex(actual_c)), Value::Number(Number::Complex(expected_c))) => { - assert!( - (actual_c.re.is_infinite() && expected_c.re.is_infinite()) || (actual_c.re - expected_c.re).abs() < EPSILON, - "Expected real part {}, but got {}", - expected_c.re, - actual_c.re - ); - assert!( - (actual_c.im.is_infinite() && expected_c.im.is_infinite()) || (actual_c.im - expected_c.im).abs() < EPSILON, - "Expected imaginary part {}, but got {}", - expected_c.im, - actual_c.im - ); - } - (Value::Number(Number::Real(actual_f)), Value::Number(Number::Real(expected_f))) => { - if actual_f.is_infinite() || expected_f.is_infinite() { - assert!( - actual_f.is_infinite() && expected_f.is_infinite() && actual_f == expected_f, - "Expected infinite value {}, but got {}", - expected_f, - actual_f - ); - } else if actual_f.is_nan() || expected_f.is_nan() { - assert!(actual_f.is_nan() && expected_f.is_nan(), "Expected NaN, but got {}", actual_f); - } else { - assert!((actual_f - expected_f).abs() < EPSILON, "Expected {}, but got {}", expected_f, actual_f); - } - } - // Handle mismatched types - _ => panic!("Mismatched types: expected {:?}, got {:?}", expected_value, actual_value), - } + fn run_end_to_end_test(input: &str, expected_value: Value) { + // parse + pretty‐print any parse errors + let expr = match ast::Node::try_parse_from_str(input) { + Ok(expr) => expr, + Err(err) => { + err.print(); + panic!("failed to parse `{input}`"); + } + }; + let context = EvalContext::default(); - } - )* + let actual_value = match expr.eval(&context) { + Ok(v) => v, + Err(err) => panic!("failed to evaluate {input} becuase of error {err}"), }; + + // compare + match (actual_value, expected_value) { + (Value::Number(Number::Complex(a)), Value::Number(Number::Complex(e))) => { + // real part + if a.re.is_infinite() || e.re.is_infinite() { + assert!(a.re == e.re, "`{}` → real part: expected {:?}, got {:?}", input, e.re, a.re); + } else { + assert!((a.re - e.re).abs() < EPSILON, "`{}` → real part: expected {}, got {}", input, e.re, a.re); + } + + // imag part + if a.im.is_infinite() || e.im.is_infinite() { + assert!(a.im == e.im, "`{}` → imag part: expected {:?}, got {:?}", input, e.im, a.im); + } else { + assert!((a.im - e.im).abs() < EPSILON, "`{}` → imag part: expected {}, got {}", input, e.im, a.im); + } + } + + (Value::Number(Number::Real(a)), Value::Number(Number::Real(e))) => { + if a.is_infinite() || e.is_infinite() { + // both must be infinite and equal (i.e. both +∞ or both −∞) + assert!(a == e, "`{input}` → expected infinite {e:?}, got {a:?}"); + } else if a.is_nan() || e.is_nan() { + // both must be NaN + assert!(a.is_nan() && e.is_nan(), "`{input}` → expected NaN, got {a:?}"); + } else { + let diff = (a - e).abs(); + assert!(diff < EPSILON, "`{input}` → expected {e}, got {a}, Δ={diff}"); + } + } + + (got, expect) => { + panic!("`{input}` → mismatched types: expected {expect:?}, got {got:?}"); + } + } } + macro_rules! test_end_to_end { + ($($name:ident: $input:expr => $expected:expr),* $(,)?) => { + $( + #[test] + fn $name() { + // note the `.into()` here, so we still accept + // any `T: Into` in our macro + run_end_to_end_test($input, $expected.into()); + } + )* + }; + } test_end_to_end! { // Basic arithmetic infix_addition: "5 + 5" => 10., diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 225d0d7931..b636a519ee 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -1,5 +1,6 @@ use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; use crate::context::EvalContext; +use crate::diagnostic::{CompileError, make_compile_error}; use crate::lexer::{Lexer, Span, Token}; use crate::value::{Complex, Number, Value}; use chumsky::container::Seq; @@ -10,21 +11,23 @@ use num_complex::ComplexFloat; use std::num::{ParseFloatError, ParseIntError}; use thiserror::Error; -#[derive(Error, Debug)] -pub enum ParseError<'src> { - #[error("syntax error(s): {0:#?}")] - Parse(Vec, Span>>), -} - impl Node { - /// Lex + parse the source and either return an AST `Node` - /// or a typed `ParseError`. - pub fn try_parse_from_str(src: &str) -> Result> { + pub fn try_parse_from_str(src: &str) -> Result { let tokens = Lexer::new(src); match parser().parse(tokens).into_result() { Ok(ast) => Ok(ast), - Err(errs) => Err(ParseError::Parse(errs)), + Err(parse_errs) => { + let errs = parse_errs.into_iter().map(|e| { + let primary = e.span(); + let mut secondary = Vec::new(); + for (msg, ctx_span) in e.contexts() { + secondary.push((msg.to_string(), *ctx_span)); + } + (e.to_string(), *primary, secondary) + }); + Err(make_compile_error("expression", src, errs)) + } } } } @@ -34,35 +37,39 @@ where I: ValueInput<'src, Token = Token<'src>, Span = Span>, { recursive(|expr| { - let constant = select! {Token::Const(x) => Node::Lit(x)}; + let constant = select! { + Token::Float(f) => Node::Lit(Literal::Float(f)), + Token::Const(c) => Node::Lit(c.value()) + }; let args = expr.clone().separated_by(just(Token::Comma)).collect::>().delimited_by(just(Token::LParen), just(Token::RParen)); let if_expr = just(Token::If) - .ignore_then(args.clone()) // Parses (cond, a, b) - .try_map(|args: Vec, span| { - if args.len() != 3 { - return Err(Rich::custom(span, "Expected 3 arguments in if(cond, a, b)")); - } - let mut iter = args.into_iter(); - let cond = iter.next().unwrap(); - let if_b = iter.next().unwrap(); - let else_b = iter.next().unwrap(); - Ok(Node::Conditional { - condition: Box::new(cond), - if_block: Box::new(if_b), - else_block: Box::new(else_b), - }) - }); - - let call = select! {Token::Ident(s) => s} - .then(args) - .try_map(|(name, args): (&str, Vec), span| Ok(Node::FnCall { name: name.to_string(), expr: args })); - - let parens = expr.clone().clone().delimited_by(just(Token::LParen), just(Token::RParen)); - let var = select! { Token::Ident(name) => Node::Var(name.to_string()) }; - - let atom = choice((constant, if_expr, call, parens, var)).boxed(); + .ignore_then(args.clone()) // Parses (cond, a, b) + .try_map(|args: Vec, span| { + if args.len() != 3 { + return Err(Rich::custom(span, "Expected 3 arguments in if(cond, a, b)")); + } + let mut iter = args.into_iter(); + let cond = iter.next().unwrap(); + let if_b = iter.next().unwrap(); + let else_b = iter.next().unwrap(); + Ok(Node::Conditional { + condition: Box::new(cond), + if_block: Box::new(if_b), + else_block: Box::new(else_b), + }) + } + ); + + let ident = select! {Token::Ident(s) => s}.labelled("ident"); + + let call = ident.then(args).map(|(name, args): (&str, Vec)| Node::FnCall { name: name.to_string(), expr: args }); + + let parens = expr.clone().delimited_by(just(Token::LParen), just(Token::RParen)); + let var = ident.map(|s| Node::Var(s.to_string())); + + let atom = choice((constant, if_expr, call, parens, var)).labelled("atom").boxed(); let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div))); @@ -78,7 +85,7 @@ where let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed(); - let cmp = unary.clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { + let cmp = unary.clone().clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), @@ -105,9 +112,9 @@ where rhs: Box::new(rhs), }); - add.clone().foldl(add.map(|rhs| (BinaryOp::Mul, rhs)).repeated(), |lhs, (op, rhs)| Node::BinOp { + add.clone().foldl(add.repeated(), |lhs, rhs| Node::BinOp { lhs: Box::new(lhs), - op, + op: BinaryOp::Mul, rhs: Box::new(rhs), }) }) @@ -121,7 +128,14 @@ mod tests { $( #[test] fn $name() { - let result = Node::try_parse_from_str($input).unwrap(); + + let result = match Node::try_parse_from_str($input){ + Ok(expr) => expr, + Err(err) => { + err.print(); + panic!(concat!("failed to parse `", $input, "`")); + } + }; assert_eq!(result, $expected); } )* diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index 3ea14373f0..8120e23831 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -54,7 +54,11 @@ fn math( let node = match ast::Node::try_parse_from_str(&expression) { Ok(expr) => expr, Err(e) => { - warn!("Invalid expression: `{expression}`\n{e:?}"); + warn!("Invalid expression: `{expression}`\n"); + // print colored error + e.print(); + // render error as html + _ = e.render_html(&Default::default()); return U::from(0.).unwrap(); } }; From edcc54b6dab24e83927beff2f628a7e681b44ea9 Mon Sep 17 00:00:00 2001 From: uri singer Date: Mon, 9 Feb 2026 01:36:18 +0200 Subject: [PATCH 14/16] fix parsing bugs and add tests --- libraries/math-parser/src/ast.rs | 2 ++ libraries/math-parser/src/lexer.rs | 18 +++++++++++++++ libraries/math-parser/src/lib.rs | 30 +++++++++++++++++------- libraries/math-parser/src/parser.rs | 36 ++++++++++++++++------------- libraries/math-parser/src/value.rs | 10 ++++++++ 5 files changed, 72 insertions(+), 24 deletions(-) diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index 051db8466a..0aea8f3ad2 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -55,11 +55,13 @@ pub enum BinaryOp { Sub, Mul, Div, + Modulo, Pow, Leq, Lt, Geq, Gt, + Neq, Eq, } diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs index 4b49a08515..dd21df3dfc 100644 --- a/libraries/math-parser/src/lexer.rs +++ b/libraries/math-parser/src/lexer.rs @@ -23,6 +23,7 @@ pub enum Token<'src> { Comma, Plus, Minus, + Modulo, Star, Slash, Caret, @@ -31,6 +32,7 @@ pub enum Token<'src> { Le, Gt, Ge, + Neq, EqEq, If, @@ -48,6 +50,7 @@ impl<'src> fmt::Display for Token<'src> { Token::Comma => f.write_str(","), Token::Plus => f.write_str("+"), Token::Minus => f.write_str("-"), + Token::Modulo => f.write_str("%"), Token::Star => f.write_str("*"), Token::Slash => f.write_str("/"), Token::Caret => f.write_str("^"), @@ -56,6 +59,7 @@ impl<'src> fmt::Display for Token<'src> { Token::Le => f.write_str("<="), Token::Gt => f.write_str(">"), Token::Ge => f.write_str(">="), + Token::Neq => f.write_str("!="), Token::EqEq => f.write_str("=="), Token::If => f.write_str("if"), @@ -220,9 +224,21 @@ impl<'a> Lexer<'a> { '+' => Plus, '-' => Minus, '*' => Star, + '%' => Modulo, '/' => Slash, '^' => Caret, + '≠' => Neq, + '!' => { + if self.peek() == Some('=') { + self.bump(); + Neq + } else { + return None; + } + } + + '≤' => Le, '<' => { if self.peek() == Some('=') { self.bump(); @@ -231,6 +247,8 @@ impl<'a> Lexer<'a> { Lt } } + + '≥' => Ge, '>' => { if self.peek() == Some('=') { self.bump(); diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index 6c37ff78fb..dd33b2be8b 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -34,7 +34,6 @@ mod tests { const EPSILON: f64 = 1e-10_f64; fn run_end_to_end_test(input: &str, expected_value: Value) { - // parse + pretty‐print any parse errors let expr = match ast::Node::try_parse_from_str(input) { Ok(expr) => expr, Err(err) => { @@ -42,6 +41,7 @@ mod tests { panic!("failed to parse `{input}`"); } }; + dbg!(&expr); let context = EvalContext::default(); let actual_value = match expr.eval(&context) { @@ -91,9 +91,7 @@ mod tests { $( #[test] fn $name() { - // note the `.into()` here, so we still accept - // any `T: Into` in our macro - run_end_to_end_test($input, $expected.into()); + run_end_to_end_test($input, ($expected).into()); } )* }; @@ -101,9 +99,17 @@ mod tests { test_end_to_end! { // Basic arithmetic infix_addition: "5 + 5" => 10., - infix_subtraction_units: "5 - 3" => 2., - infix_multiplication_units: "4 * 4" => 16., - infix_division_units: "8/2" => 4., + infix_subtraction: "5 - 3" => 2., + infix_multiplication: "4 * 4" => 16., + infix_division: "8/2" => 4., + modulo_pos_pos: "3.2 % 2" => 1.2, + modulo_pos_neg: "3.2 % -2" => 1.2, + modulo_neg_neg: "(-3.2) % -2" => -1.2, + modulo_neg_pos: "(-3.2) % 2" => -1.2, + exp_pos_pos: "3.2 ^ 2" => 256. / 25., + exp_pos_neg: "3.2 ^ -2" => 25. / 256., + exp_neg_neg: "-3.2 ^ -2" => -25. / 256., + exp_neg_pos: "-3.2 ^ 2" => -256. / 25., // Order of operations order_of_operations_negative_prefix: "-10 + 5" => -5., @@ -168,7 +174,8 @@ mod tests { // Nested arithmetic if_complex_arithmetic: "if((5+3)*(2-1), 10, 20)" => 10., - if_with_division: "if(8/4-2 == 0,15, 25)" => 15., + if_with_division: "if(8/4-2 == 0, 15, 25)" => 15., + if_with_division_ne: "if(8/4-2 ≠ 0, 15, 25)" => 25., // Constants in conditions if_with_pi: "if(pi > 3, 1, 0)" => 1., @@ -186,6 +193,13 @@ mod tests { if_complex_condition: "if(sqrt(16) + sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, if_complex_blocks: "if(1, 2*sqrt(16) + sin(pi/2), 3*cos(0) + 4)" => 9., + // TODO: Combine into one test with && + le: "if(1 <= 2, 1., 0.)" => 1., + le_special: "if(1 ≤ 2, 1., 0.)" => 1., + + ge: "if(2 >= 1, 1, 0)" => 1., + ge_special: "if(2 ≥ 1, 1, 0)" => 1., + // Edge cases if_zero: "if(0.0, 1, 2)" => 2., diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index b636a519ee..2d5e9f59cc 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -72,7 +72,7 @@ where let atom = choice((constant, if_expr, call, parens, var)).labelled("atom").boxed(); let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); - let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div))); + let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div), just(Token::Modulo).to(BinaryOp::Modulo))); let pow_op = just(Token::Caret).to(BinaryOp::Pow); let unary_op = just(Token::Minus).to(UnaryOp::Neg); let cmp_op = choice(( @@ -80,26 +80,24 @@ where just(Token::Le).to(BinaryOp::Leq), just(Token::Gt).to(BinaryOp::Gt), just(Token::Ge).to(BinaryOp::Geq), + just(Token::Neq).to(BinaryOp::Neq), just(Token::EqEq).to(BinaryOp::Eq), )); - let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed(); - - let cmp = unary.clone().clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }); + let pow = atom.clone().foldl( + pow_op.then(unary_op.clone().repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed()).repeated(), + |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }, + ); - let pow = cmp.clone().foldl(pow_op.then(cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }); + let unary = unary_op.repeated().foldr(pow, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed(); - let product = pow + let product = unary .clone() - .foldl(mul_op.then(pow).repeated(), |lhs, (op, rhs)| Node::BinOp { + .foldl(mul_op.then(unary).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), @@ -112,7 +110,13 @@ where rhs: Box::new(rhs), }); - add.clone().foldl(add.repeated(), |lhs, rhs| Node::BinOp { + let cmp = add.clone().foldl(cmp_op.then(add).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + cmp.clone().foldl(cmp.repeated(), |lhs, rhs| Node::BinOp { lhs: Box::new(lhs), op: BinaryOp::Mul, rhs: Box::new(rhs), diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index c4e5217ba8..c7792f9819 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -60,11 +60,13 @@ impl Number { BinaryOp::Sub => lhs - rhs, BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, + BinaryOp::Modulo => lhs % rhs, BinaryOp::Pow => lhs.powf(rhs), BinaryOp::Leq => (lhs <= rhs) as u8 as f64, BinaryOp::Lt => (lhs < rhs) as u8 as f64, BinaryOp::Geq => (lhs >= rhs) as u8 as f64, BinaryOp::Gt => (lhs > rhs) as u8 as f64, + BinaryOp::Neq => (lhs != rhs) as u8 as f64, BinaryOp::Eq => (lhs == rhs) as u8 as f64, }; @@ -77,10 +79,18 @@ impl Number { BinaryOp::Sub => lhs - rhs, BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, + BinaryOp::Modulo => lhs % rhs, BinaryOp::Pow => lhs.powc(rhs), BinaryOp::Leq | BinaryOp::Lt | BinaryOp::Geq | BinaryOp::Gt => { return None; } + BinaryOp::Neq => { + if lhs != rhs { + return Some(Number::Real(1.0)); + } else { + return Some(Number::Real(0.0)); + } + } BinaryOp::Eq => { if lhs == rhs { return Some(Number::Real(1.0)); From 5001ee568ae667a9fd4579d5ec2626064f71911a Mon Sep 17 00:00:00 2001 From: uri singer Date: Fri, 13 Mar 2026 11:50:45 +0200 Subject: [PATCH 15/16] add all functions, add tests --- libraries/math-parser/src/ast.rs | 5 + libraries/math-parser/src/constants.rs | 290 ++++++++++++++++++++++++- libraries/math-parser/src/executer.rs | 1 + libraries/math-parser/src/lexer.rs | 27 ++- libraries/math-parser/src/lib.rs | 87 +++++++- libraries/math-parser/src/parser.rs | 55 ++++- libraries/math-parser/src/value.rs | 50 ++++- 7 files changed, 496 insertions(+), 19 deletions(-) diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index 0aea8f3ad2..4b4c0be625 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -54,7 +54,11 @@ pub enum BinaryOp { Add, Sub, Mul, + /// Logical AND (nonzero treated as true, returns 1.0 or 0.0) + And, Div, + /// Logical OR (nonzero treated as true, returns 1.0 or 0.0) + Or, Modulo, Pow, Leq, @@ -70,6 +74,7 @@ pub enum UnaryOp { Neg, Sqrt, Fac, + Not, } #[derive(Debug, PartialEq)] diff --git a/libraries/math-parser/src/constants.rs b/libraries/math-parser/src/constants.rs index cdf3b2439e..d2ea380f3f 100644 --- a/libraries/math-parser/src/constants.rs +++ b/libraries/math-parser/src/constants.rs @@ -71,6 +71,7 @@ lazy_static! { }), ); + // Inverse trig with legacy names and standard aliases map.insert( "invsin", Box::new(|values| match values { @@ -79,6 +80,14 @@ lazy_static! { _ => None, }), ); + map.insert( + "asin", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.asin()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.asin()))), + _ => None, + }), + ); map.insert( "invcos", @@ -88,6 +97,14 @@ lazy_static! { _ => None, }), ); + map.insert( + "acos", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.acos()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.acos()))), + _ => None, + }), + ); map.insert( "invtan", @@ -97,6 +114,14 @@ lazy_static! { _ => None, }), ); + map.insert( + "atan", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.atan()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.atan()))), + _ => None, + }), + ); map.insert( "invcsc", @@ -106,6 +131,14 @@ lazy_static! { _ => None, }), ); + map.insert( + "acsc", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().asin()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().asin()))), + _ => None, + }), + ); map.insert( "invsec", @@ -115,12 +148,36 @@ lazy_static! { _ => None, }), ); + map.insert( + "asec", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().acos()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().acos()))), + _ => None, + }), + ); map.insert( "invcot", Box::new(|values| match values { - [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real((PI / 2.0 - real).atan()))), - [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex((Complex::new(PI / 2.0, 0.0) - complex).atan()))), + [Value::Number(Number::Real(real))] => { + Some(Value::Number(Number::Real(real.recip().atan()))) + } + [Value::Number(Number::Complex(complex))] => { + Some(Value::Number(Number::Complex(complex.recip().atan()))) + } + _ => None, + }), + ); + map.insert( + "acot", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => { + Some(Value::Number(Number::Real(real.recip().atan()))) + } + [Value::Number(Number::Complex(complex))] => { + Some(Value::Number(Number::Complex(complex.recip().atan()))) + } _ => None, }), ); @@ -152,6 +209,34 @@ lazy_static! { }), ); + // Reciprocal hyperbolic functions + map.insert( + "csch", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sinh().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sinh().recip()))), + _ => None, + }), + ); + + map.insert( + "sech", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cosh().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cosh().recip()))), + _ => None, + }), + ); + + map.insert( + "coth", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tanh().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tanh().recip()))), + _ => None, + }), + ); + // Inverse Hyperbolic Functions map.insert( "asinh", @@ -180,6 +265,34 @@ lazy_static! { }), ); + // Inverse reciprocal hyperbolic functions + map.insert( + "acsch", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().asinh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().asinh()))), + _ => None, + }), + ); + + map.insert( + "asech", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().acosh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().acosh()))), + _ => None, + }), + ); + + map.insert( + "acoth", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().atanh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().atanh()))), + _ => None, + }), + ); + // Logarithm Functions map.insert( "ln", @@ -190,6 +303,45 @@ lazy_static! { }), ); + // Exponential / power helpers + map.insert( + "exp", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.exp()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.exp()))), + _ => None, + }), + ); + + map.insert( + "pow", + Box::new(|values| match values { + [Value::Number(Number::Real(x)), Value::Number(Number::Real(n))] => { + Some(Value::Number(Number::Real(x.powf(*n)))) + } + [Value::Number(Number::Complex(x)), Value::Number(Number::Real(n))] => { + Some(Value::Number(Number::Complex(x.powf(*n)))) + } + [Value::Number(Number::Complex(x)), Value::Number(Number::Complex(n))] => { + Some(Value::Number(Number::Complex(x.powc(*n)))) + } + _ => None, + }), + ); + + map.insert( + "root", + Box::new(|values| match values { + [Value::Number(Number::Real(x)), Value::Number(Number::Real(n))] => { + Some(Value::Number(Number::Real(x.powf(1.0 / *n)))) + } + [Value::Number(Number::Complex(x)), Value::Number(Number::Real(n))] => { + Some(Value::Number(Number::Complex(x.powf(1.0 / *n)))) + } + _ => None, + }), + ); + map.insert( "log", Box::new(|values| match values { @@ -246,6 +398,16 @@ lazy_static! { }), ); + map.insert( + "atan2", + Box::new(|values| match values { + [Value::Number(Number::Real(y)), Value::Number(Number::Real(x))] => { + Some(Value::Number(Number::Real(y.atan2(*x)))) + } + _ => None, + }), + ); + // Mapping Functions map.insert( "abs", @@ -300,6 +462,105 @@ lazy_static! { }), ); + map.insert( + "remap", + Box::new(|values| match values { + [ + Value::Number(Number::Real(value)), + Value::Number(Number::Real(in_a)), + Value::Number(Number::Real(in_b)), + Value::Number(Number::Real(out_a)), + Value::Number(Number::Real(out_b)), + ] => { + let t = (*value - *in_a) / (*in_b - *in_a); + Some(Value::Number(Number::Real(out_a + t * (out_b - out_a)))) + } + _ => None, + }), + ); + + map.insert( + "trunc", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.trunc()))), + _ => None, + }), + ); + + map.insert( + "fract", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.fract()))), + _ => None, + }), + ); + + map.insert( + "sign", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => { + let s = if *real > 0.0 { + 1.0 + } else if *real < 0.0 { + -1.0 + } else { + 0.0 + }; + Some(Value::Number(Number::Real(s))) + } + _ => None, + }), + ); + + map.insert( + "gcd", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + let mut x = a.trunc() as i64; + let mut y = b.trunc() as i64; + if x == 0 && y == 0 { + return Some(Value::Number(Number::Real(0.0))); + } + x = x.abs(); + y = y.abs(); + while y != 0 { + let r = x % y; + x = y; + y = r; + } + Some(Value::Number(Number::Real(x as f64))) + } + _ => None, + }), + ); + + map.insert( + "lcm", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + let mut x = a.trunc() as i64; + let mut y = b.trunc() as i64; + x = x.abs(); + y = y.abs(); + if x == 0 || y == 0 { + return Some(Value::Number(Number::Real(0.0))); + } + + // gcd + let mut gx = x; + let mut gy = y; + while gy != 0 { + let r = gx % gy; + gx = gy; + gy = r; + } + let lcm = (x / gx) * y; + Some(Value::Number(Number::Real(lcm as f64))) + } + _ => None, + }), + ); + // Complex Number Functions map.insert( "real", @@ -319,6 +580,31 @@ lazy_static! { }), ); + map.insert( + "conj", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => { + Some(Value::Number(Number::Complex(complex.conj()))) + } + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(*real))), + _ => None, + }), + ); + + map.insert( + "arg", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => { + Some(Value::Number(Number::Real(complex.arg()))) + } + [Value::Number(Number::Real(real))] => { + let angle = if *real >= 0.0 { 0.0 } else { PI }; + Some(Value::Number(Number::Real(angle))) + } + _ => None, + }), + ); + // Logical Functions map.insert( "isnan", diff --git a/libraries/math-parser/src/executer.rs b/libraries/math-parser/src/executer.rs index 85871d769c..f795735093 100644 --- a/libraries/math-parser/src/executer.rs +++ b/libraries/math-parser/src/executer.rs @@ -33,6 +33,7 @@ impl Node { Node::Var(name) => context.get_value(name).ok_or_else(|| EvalError::MissingValue(name.clone())), Node::FnCall { name, expr } => { let values = expr.iter().map(|expr| expr.eval(context)).collect::, EvalError>>()?; + if let Some(function) = DEFAULT_FUNCTIONS.get(&name.as_str()) { function(&values).ok_or(EvalError::TypeError) } else if let Some(val) = context.run_function(name, &values) { diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs index dd21df3dfc..4b1e4c7a1b 100644 --- a/libraries/math-parser/src/lexer.rs +++ b/libraries/math-parser/src/lexer.rs @@ -18,6 +18,10 @@ pub enum Token<'src> { Const(Constant), Ident(&'src str), + AndAnd, + OrOr, + Bang, + LParen, RParen, Comma, @@ -45,6 +49,10 @@ impl<'src> fmt::Display for Token<'src> { Token::Const(c) => write!(f, "{c}"), Token::Ident(name) => write!(f, "{name}"), + Token::AndAnd => f.write_str("&&"), + Token::OrOr => f.write_str("||"), + Token::Bang => f.write_str("!"), + Token::LParen => f.write_str("("), Token::RParen => f.write_str(")"), Token::Comma => f.write_str(","), @@ -218,6 +226,23 @@ impl<'a> Lexer<'a> { use Token::*; let tok = match ch { + '&' => { + if self.peek() == Some('&') { + self.bump(); + AndAnd + } else { + return None; + } + } + '|' => { + if self.peek() == Some('|') { + self.bump(); + OrOr + } else { + return None; + } + } + '(' => LParen, ')' => RParen, ',' => Comma, @@ -234,7 +259,7 @@ impl<'a> Lexer<'a> { self.bump(); Neq } else { - return None; + Bang } } diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index dd33b2be8b..6889e5b27d 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -137,6 +137,12 @@ mod tests { exponent_mixed_operations: "2^3 + 4^2" => 24., exponent_nested: "2^(3+1)" => 16., + // Factorial (postfix !) + factorial_simple: "5!" => 120., + factorial_nested: "(3 + 2)!" => 120., + factorial_zero: "0!" => 1., + factorial_chain: "3!!" => 720., // (3!)! = 6! = 720 + // Operations with negative values negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => -8., negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => -7., @@ -149,8 +155,7 @@ mod tests { constant_e: "e" => std::f64::consts::E, constant_phi: "phi" => 1.61803398875, constant_tau: "tau" => 2.0 * std::f64::consts::PI, - constant_infinity: "inf" => f64::INFINITY, - constant_infinity_symbol: "∞" => f64::INFINITY, + constant_infinity: "if(inf == ∞, inf, 0)" => f64::INFINITY, multiply_pi: "2 * pi" => 2.0 * std::f64::consts::PI, add_e_constant: "e + 1" => std::f64::consts::E + 1.0, multiply_phi_constant: "phi * 2" => 1.61803398875 * 2.0, @@ -163,6 +168,33 @@ mod tests { trig_tan_pi_div_four: "tan(pi/4)" => 1.0, trig_sin_tau: "sin(tau)" => 0.0, trig_cos_tau_div_two: "cos(tau/2)" => -1.0, + trig_csc: "csc(pi/2)" => 1.0, + trig_sec: "sec(0)" => 1.0, + trig_cot: "cot(pi/4)" => 1.0, + + // Inverse trig aliases + inverse_trig_asin: "asin(1)" => std::f64::consts::FRAC_PI_2, + inverse_trig_acos: "acos(1)" => 0.0, + inverse_trig_atan: "atan(1)" => std::f64::consts::FRAC_PI_4, + inverse_trig_acsc: "acsc(1)" => std::f64::consts::FRAC_PI_2, + inverse_trig_asec: "asec(1)" => 0.0, + inverse_trig_acot: "acot(1)" => std::f64::consts::FRAC_PI_4, + + // Hyperbolic and reciprocal hyperbolic + hyperbolic_sinh: "sinh(0)" => 0.0, + hyperbolic_cosh: "cosh(0)" => 1.0, + hyperbolic_tanh: "tanh(0)" => 0.0, + hyperbolic_csch: "csch(1)" => 1f64.sinh().recip(), + hyperbolic_sech: "sech(0)" => 1.0, + hyperbolic_coth: "coth(1)" => 1f64.tanh().recip(), + + // Inverse hyperbolic + inverse_hyperbolic_asinh: "asinh(0)" => 0.0, + inverse_hyperbolic_acosh: "acosh(1)" => 0.0, + inverse_hyperbolic_atanh: "atanh(0)" => 0.0, + inverse_hyperbolic_acsch: "acsch(1)" => 1.0f64.asinh(), + inverse_hyperbolic_asech: "asech(1)" => 1.0f64.acosh(), + inverse_hyperbolic_acoth: "acoth(2)" => 0.5f64.atanh(), // Basic if statements if_true_condition: "if(1,5,3)" => 5., @@ -185,6 +217,26 @@ mod tests { if_with_sqrt: "if(sqrt(16) == 4, 1, 0)" => 1., if_with_sin: "if(sin(pi) == 0.0, 1, 0)" => 0., + // Logical NOT (prefix !) + logical_not_zero: "!0" => 1.0, + logical_not_nonzero: "!5" => 0.0, + logical_not_expression: "!(2 - 2)" => 1.0, + + // Logical helpers as functions + logical_isnan: "isnan(0/0)" => 1.0, + logical_eq: "eq(2, 2)" => 1.0, + logical_greater: "greater(3, 2)" => 1.0, + + // Log / exp / pow / root + log_ln: "ln(e)" => 1.0, + log_log10: "log(100)" => 2.0, + log_log2: "log2(8)" => 3.0, + log_change_of_base: "log(8, 2)" => 3.0, + exp_function: "exp(1)" => std::f64::consts::E, + pow_real: "pow(2, 3)" => 8.0, + root_square: "root(9, 2)" => 3.0, + root_cube: "root(8, 3)" => 2.0, + // Nested if statements nested_if: "if(1, if(0, 1, 2), 3)" => 2., nested_if_complex: "if(2-2 == 0, if(1, 5, 6), if(1, 7, 8))" => 5., @@ -193,12 +245,33 @@ mod tests { if_complex_condition: "if(sqrt(16) + sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, if_complex_blocks: "if(1, 2*sqrt(16) + sin(pi/2), 3*cos(0) + 4)" => 9., - // TODO: Combine into one test with && - le: "if(1 <= 2, 1., 0.)" => 1., - le_special: "if(1 ≤ 2, 1., 0.)" => 1., + // Mapping helpers + mapping_trunc: "trunc(3.7)" => 3.0, + mapping_fract: "fract(3.25)" => 0.25, + mapping_sign_pos: "sign(5)" => 1.0, + mapping_sign_neg: "sign(-5)" => -1.0, + + // Geometry / mapping extras + geometry_hypot: "hypot(3, 4)" => 5.0, + mapping_remap: "remap(5, 0, 10, 0, 100)" => 50.0, + + // GCD / LCM + gcd_simple: "gcd(24, 18)" => 6.0, + lcm_simple: "lcm(4, 6)" => 12.0, + + // atan2 + trig_atan2_axis: "atan2(1, 0)" => std::f64::consts::FRAC_PI_2, + + // Comparison operators combined with logical AND + comparison_operators: "if(1 <= 2 && 1 ≤ 2 && 2 >= 1 && 2 ≥ 1, 1., 0.)" => 1., - ge: "if(2 >= 1, 1, 0)" => 1., - ge_special: "if(2 ≥ 1, 1, 0)" => 1., + // Logical AND / OR + logical_and_true: "if(1 <= 2 && 2 < 3, 1., 0.)" => 1., + logical_and_false: "if(1 <= 2 && 3 < 2, 1., 0.)" => 0., + logical_or_true_left: "if(1 > 2 || 2 < 3, 1., 0.)" => 1., + logical_or_true_right: "if(2 < 1 || 2 < 3, 1., 0.)" => 1., + logical_or_false: "if(1 > 2 || 3 < 2, 1., 0.)" => 0., + logical_precedence_and_over_or: "if(0 == 1 || 1 == 1 && 0 == 0, 1., 0.)" => 1., // Edge cases if_zero: "if(0.0, 1, 2)" => 2., diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 2d5e9f59cc..fc88e6aa5c 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -74,7 +74,12 @@ where let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div), just(Token::Modulo).to(BinaryOp::Modulo))); let pow_op = just(Token::Caret).to(BinaryOp::Pow); - let unary_op = just(Token::Minus).to(UnaryOp::Neg); + let unary_op = choice(( + just(Token::Minus).to(UnaryOp::Neg), + just(Token::Bang).to(UnaryOp::Not), + )); + let and_op = just(Token::AndAnd).to(BinaryOp::And); + let or_op = just(Token::OrOr).to(BinaryOp::Or); let cmp_op = choice(( just(Token::Lt).to(BinaryOp::Lt), just(Token::Le).to(BinaryOp::Leq), @@ -84,8 +89,28 @@ where just(Token::EqEq).to(BinaryOp::Eq), )); - let pow = atom.clone().foldl( - pow_op.then(unary_op.clone().repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed()).repeated(), + // Postfix factorial: expr! → UnaryOp::Fac + let postfix = atom + .clone() + .foldl(just(Token::Bang).repeated(), |expr, _| Node::UnaryOp { + op: UnaryOp::Fac, + expr: Box::new(expr), + }) + .boxed(); + + let pow = postfix.clone().foldl( + pow_op + .then( + unary_op + .clone() + .repeated() + .foldr(postfix, |op, expr| Node::UnaryOp { + op, + expr: Box::new(expr), + }) + .boxed(), + ) + .repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, @@ -116,11 +141,31 @@ where rhs: Box::new(rhs), }); - cmp.clone().foldl(cmp.repeated(), |lhs, rhs| Node::BinOp { + // Chain comparisons like `a < b < c` by multiplying the boolean + // (1.0 / 0.0) results, preserving the existing semantics. + let chained_cmp = cmp.clone().foldl(cmp.repeated(), |lhs, rhs| Node::BinOp { lhs: Box::new(lhs), op: BinaryOp::Mul, rhs: Box::new(rhs), - }) + }); + + let and = chained_cmp + .clone() + .foldl(and_op.then(chained_cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + let or = and + .clone() + .foldl(or_op.then(and).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + or }) } diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index c7792f9819..74184c0575 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -56,6 +56,16 @@ impl Number { match (self, other) { (Number::Real(lhs), Number::Real(rhs)) => { let result = match op { + BinaryOp::And => { + let l = lhs != 0.0; + let r = rhs != 0.0; + if l && r { 1.0 } else { 0.0 } + } + BinaryOp::Or => { + let l = lhs != 0.0; + let r = rhs != 0.0; + if l || r { 1.0 } else { 0.0 } + } BinaryOp::Add => lhs + rhs, BinaryOp::Sub => lhs - rhs, BinaryOp::Mul => lhs * rhs, @@ -75,6 +85,16 @@ impl Number { (Number::Complex(lhs), Number::Complex(rhs)) => { let result = match op { + BinaryOp::And => { + let l = lhs != Complex::new(0.0, 0.0); + let r = rhs != Complex::new(0.0, 0.0); + return Some(Number::Real(if l && r { 1.0 } else { 0.0 })); + } + BinaryOp::Or => { + let l = lhs != Complex::new(0.0, 0.0); + let r = rhs != Complex::new(0.0, 0.0); + return Some(Number::Real(if l || r { 1.0 } else { 0.0 })); + } BinaryOp::Add => lhs + rhs, BinaryOp::Sub => lhs - rhs, BinaryOp::Mul => lhs * rhs, @@ -135,15 +155,37 @@ impl Number { Number::Real(real) => match op { UnaryOp::Neg => Number::Real(-real), UnaryOp::Sqrt => Number::Real(real.sqrt()), - - UnaryOp::Fac => todo!("Implement factorial"), + UnaryOp::Fac => { + // n! for real n: use integer semantics when n is a + // non-negative integer, otherwise return NaN. + if !real.is_finite() { + return Number::Real(f64::NAN); + } + let truncated = real.trunc(); + if truncated < 0.0 || (real - truncated).abs() > f64::EPSILON { + return Number::Real(f64::NAN); + } + let n = truncated as u64; + let mut acc = 1.0_f64; + for k in 1..=n { + acc *= k as f64; + } + Number::Real(acc) + } + UnaryOp::Not => { + let is_zero = real == 0.0; + Number::Real(if is_zero { 1.0 } else { 0.0 }) + } }, Number::Complex(complex) => match op { UnaryOp::Neg => Number::Complex(-complex), UnaryOp::Sqrt => Number::Complex(complex.sqrt()), - - UnaryOp::Fac => todo!("Implement factorial"), + UnaryOp::Fac => Number::Complex(Complex::new(f64::NAN, f64::NAN)), + UnaryOp::Not => { + let is_zero = complex == Complex::new(0.0, 0.0); + Number::Real(if is_zero { 1.0 } else { 0.0 }) + } }, } } From c2f0c1167eb0bc242ad666f255dd1d1fd2e51134 Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Wed, 1 Apr 2026 19:44:39 -0700 Subject: [PATCH 16/16] Format --- libraries/math-parser/src/parser.rs | 40 +++++++++-------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index fc88e6aa5c..792c3be46a 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -74,10 +74,7 @@ where let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div), just(Token::Modulo).to(BinaryOp::Modulo))); let pow_op = just(Token::Caret).to(BinaryOp::Pow); - let unary_op = choice(( - just(Token::Minus).to(UnaryOp::Neg), - just(Token::Bang).to(UnaryOp::Not), - )); + let unary_op = choice((just(Token::Minus).to(UnaryOp::Neg), just(Token::Bang).to(UnaryOp::Not))); let and_op = just(Token::AndAnd).to(BinaryOp::And); let or_op = just(Token::OrOr).to(BinaryOp::Or); let cmp_op = choice(( @@ -100,16 +97,7 @@ where let pow = postfix.clone().foldl( pow_op - .then( - unary_op - .clone() - .repeated() - .foldr(postfix, |op, expr| Node::UnaryOp { - op, - expr: Box::new(expr), - }) - .boxed(), - ) + .then(unary_op.clone().repeated().foldr(postfix, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed()) .repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), @@ -149,21 +137,17 @@ where rhs: Box::new(rhs), }); - let and = chained_cmp - .clone() - .foldl(and_op.then(chained_cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }); + let and = chained_cmp.clone().foldl(and_op.then(chained_cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); - let or = and - .clone() - .foldl(or_op.then(and).repeated(), |lhs, (op, rhs)| Node::BinOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - }); + let or = and.clone().foldl(or_op.then(and).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); or })