Skip to content

Commit b420dbc

Browse files
committed
fix benches and improve compile time
1 parent b954b22 commit b420dbc

2 files changed

Lines changed: 28 additions & 309 deletions

File tree

libraries/math-parser/benches/bench.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ macro_rules! generate_benchmarks {
1616

1717
fn evaluation_bench(c: &mut Criterion) {
1818
$(
19-
let expr = ast::Node::try_parse_from_str($input).unwrap().0;
19+
let expr = ast::Node::try_parse_from_str($input).unwrap();
2020
let context = EvalContext::default();
2121

2222
c.bench_function(concat!("eval ", $input), |b| {

libraries/math-parser/src/parser.rs

Lines changed: 27 additions & 308 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,26 @@ use num_complex::ComplexFloat;
88
use std::num::{ParseFloatError, ParseIntError};
99
use thiserror::Error;
1010

11-
/*#[derive(Parser)]
12-
#[grammar = "./grammer.pest"]
13-
struct ExprParser;
14-
15-
lazy_static! {
16-
static ref PRATT_PARSER: PrattParser<Rule> = {
17-
PrattParser::new()
18-
.op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left))
19-
.op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left))
20-
.op(Op::infix(Rule::pow, Assoc::Right))
21-
.op(Op::infix(Rule::geq, Assoc::Left) | Op::infix(Rule::leq, Assoc::Left) | Op::infix(Rule::eq, Assoc::Left))
22-
.op(Op::postfix(Rule::EOI))
23-
.op(Op::prefix(Rule::sqrt))
24-
.op(Op::prefix(Rule::neg))
25-
};
26-
}*/
27-
28-
#[derive(Error, Debug)]
29-
pub enum TypeError {
30-
#[error("Invalid BinOp: {0:?} {1:?} {2:?}")]
31-
InvalidBinaryOp(Unit, BinaryOp, Unit),
32-
33-
#[error("Invalid UnaryOp: {0:?}")]
34-
InvalidUnaryOp(Unit, UnaryOp),
35-
}
36-
3711
#[derive(Error, Debug)]
3812
pub enum ParseError<'src> {
39-
/// One or more syntax/lexing errors produced by Chumsky.
4013
#[error("Syntax error(s): {0:#?}")]
4114
Syntax(Vec<Rich<'src, char>>),
15+
#[error("nig")]
16+
Nigga(),
17+
}
18+
19+
#[cfg(not(debug_assertions))] // In release build: skip boxing
20+
macro_rules! maybe_boxed {
21+
($p:expr) => {
22+
$p
23+
};
24+
}
25+
26+
#[cfg(debug_assertions)] // In debug build: box for faster compilation
27+
macro_rules! maybe_boxed {
28+
($p:expr) => {
29+
$p.boxed()
30+
};
4231
}
4332

4433
impl Node {
@@ -109,7 +98,7 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er
10998
else_block: Box::new(else_b),
11099
});
111100

112-
let atom = choice((conditional, float, constant, call, parens, var));
101+
let atom = maybe_boxed!(choice((conditional, float, constant, call, parens, var)));
113102

114103
let unary = choice((just('-').to(UnaryOp::Neg), just("sqrt").to(UnaryOp::Sqrt)))
115104
.padded()
@@ -122,15 +111,16 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er
122111
rhs: Box::new(rhs),
123112
});
124113

125-
let product = pow
126-
.clone()
127-
.foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| {
128-
Node::BinOp {
129-
lhs: Box::new(lhs),
130-
op,
131-
rhs: Box::new(rhs),
132-
}
133-
});
114+
let product = maybe_boxed!(
115+
pow.clone()
116+
.foldl(choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div))).padded().then(pow).repeated(), |lhs, (op, rhs)| {
117+
Node::BinOp {
118+
lhs: Box::new(lhs),
119+
op,
120+
rhs: Box::new(rhs),
121+
}
122+
})
123+
);
134124

135125
let sum = product.clone().foldl(
136126
choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded().then(product).repeated(),
@@ -157,277 +147,6 @@ pub fn chumsky_parser<'a>() -> impl Parser<'a, &'a str, Node, chumsky::extra::Er
157147
})
158148
}
159149

160-
struct NodeMetadata {
161-
pub unit: Unit,
162-
}
163-
164-
impl NodeMetadata {
165-
pub fn new(unit: Unit) -> Self {
166-
Self { unit }
167-
}
168-
}
169-
170-
/*fn parse_unit(pairs: Pairs<Rule>) -> Result<(Unit, f64), ParseError> {
171-
let mut scale = 1.0;
172-
let mut length = 0;
173-
let mut mass = 0;
174-
let mut time = 0;
175-
176-
for pair in pairs {
177-
println!("found rule: {:?}", pair.as_rule());
178-
match pair.as_rule() {
179-
Rule::nano => scale *= 1e-9,
180-
Rule::micro => scale *= 1e-6,
181-
Rule::milli => scale *= 1e-3,
182-
Rule::centi => scale *= 1e-2,
183-
Rule::deci => scale *= 1e-1,
184-
Rule::deca => scale *= 1e1,
185-
Rule::hecto => scale *= 1e2,
186-
Rule::kilo => scale *= 1e3,
187-
Rule::mega => scale *= 1e6,
188-
Rule::giga => scale *= 1e9,
189-
Rule::tera => scale *= 1e12,
190-
191-
Rule::meter => length = 1,
192-
Rule::gram => mass = 1,
193-
Rule::second => time = 1,
194-
195-
_ => unreachable!(), // All possible rules should be covered
196-
}
197-
}
198-
199-
Ok((Unit { length, mass, time }, scale))
200-
}
201-
202-
fn parse_const(pair: Pair<Rule>) -> Literal {
203-
match pair.as_rule() {
204-
Rule::infinity => Literal::Float(f64::INFINITY),
205-
Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)),
206-
Rule::pi => Literal::Float(std::f64::consts::PI),
207-
Rule::tau => Literal::Float(2.0 * std::f64::consts::PI),
208-
Rule::euler_number => Literal::Float(std::f64::consts::E),
209-
Rule::golden_ratio => Literal::Float(1.61803398875),
210-
_ => unreachable!("Unexpected constant: {:?}", pair),
211-
}
212-
}
213-
214-
fn parse_lit(mut pairs: Pairs<Rule>) -> Result<(Literal, Unit), ParseError> {
215-
let literal = match pairs.next() {
216-
Some(lit) => match lit.as_rule() {
217-
Rule::int => {
218-
let value = lit.as_str().parse::<i32>()? as f64;
219-
Literal::Float(value)
220-
}
221-
Rule::float => {
222-
let value = lit.as_str().parse::<f64>()?;
223-
Literal::Float(value)
224-
}
225-
Rule::unit => {
226-
let (unit, scale) = parse_unit(lit.into_inner())?;
227-
return Ok((Literal::Float(scale), unit));
228-
}
229-
rule => unreachable!("unexpected rule: {:?}", rule),
230-
},
231-
None => unreachable!("expected rule"), // No literal found
232-
};
233-
234-
if let Some(unit_pair) = pairs.next() {
235-
let unit_pairs = unit_pair.into_inner(); // Get the inner pairs for the unit
236-
let (unit, scale) = parse_unit(unit_pairs)?;
237-
238-
println!("found unit: {:?}", unit);
239-
240-
Ok((
241-
match literal {
242-
Literal::Float(num) => Literal::Float(num * scale),
243-
Literal::Complex(num) => Literal::Complex(num * scale),
244-
},
245-
unit,
246-
))
247-
} else {
248-
Ok((literal, Unit::BASE_UNIT))
249-
}
250-
}
251-
252-
fn parse_expr(pairs: Pairs<Rule>) -> Result<(Node, NodeMetadata), ParseError> {
253-
PRATT_PARSER
254-
.map_primary(|primary| {
255-
Ok(match primary.as_rule() {
256-
Rule::lit => {
257-
let (lit, unit) = parse_lit(primary.into_inner())?;
258-
259-
(Node::Lit(lit), NodeMetadata { unit })
260-
}
261-
Rule::fn_call => {
262-
let mut pairs = primary.into_inner();
263-
let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string();
264-
265-
(
266-
Node::FnCall {
267-
name,
268-
expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::<Result<Vec<Node>, ParseError>>()?,
269-
},
270-
NodeMetadata::new(Unit::BASE_UNIT),
271-
)
272-
}
273-
Rule::constant => {
274-
let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child"));
275-
276-
(Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT))
277-
}
278-
Rule::ident => {
279-
let name = primary.as_str().to_string();
280-
281-
(Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT))
282-
}
283-
Rule::expr => parse_expr(primary.into_inner())?,
284-
Rule::float => {
285-
let value = primary.as_str().parse::<f64>()?;
286-
(Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT))
287-
}
288-
Rule::conditional => {
289-
let mut inner = primary.into_inner();
290-
291-
let condition = parse_expr(inner.next().expect("expected conditional to have condition").into_inner())?;
292-
let if_block = parse_expr(inner.next().expect("expected conditional to have if block").into_inner())?;
293-
let else_block = parse_expr(inner.next().expect("expected conditional to have else block").into_inner())?;
294-
295-
(
296-
Node::Conditional {
297-
condition: Box::new(condition.0),
298-
if_block: Box::new(if_block.0),
299-
else_block: Box::new(else_block.0),
300-
},
301-
if_block.1,
302-
)
303-
}
304-
rule => unreachable!("unexpected rule: {:?}", rule),
305-
})
306-
})
307-
.map_prefix(|op, rhs| {
308-
let (rhs, rhs_metadata) = rhs?;
309-
let op = match op.as_rule() {
310-
Rule::neg => UnaryOp::Neg,
311-
Rule::sqrt => UnaryOp::Sqrt,
312-
313-
rule => unreachable!("unexpected rule: {:?}", rule),
314-
};
315-
316-
let node = Node::UnaryOp { expr: Box::new(rhs), op };
317-
let unit = rhs_metadata.unit;
318-
319-
let unit = if !unit.is_base() {
320-
match op {
321-
UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit {
322-
length: unit.length / 2,
323-
mass: unit.mass / 2,
324-
time: unit.time / 2,
325-
},
326-
UnaryOp::Neg => unit,
327-
op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))),
328-
}
329-
} else {
330-
Unit::BASE_UNIT
331-
};
332-
333-
Ok((node, NodeMetadata::new(unit)))
334-
})
335-
.map_postfix(|lhs, op| {
336-
let op = match op.as_rule() {
337-
Rule::EOI => return lhs,
338-
rule => unreachable!("unexpected rule: {:?}", rule),
339-
};
340-
})
341-
.map_infix(|lhs, op, rhs| {
342-
let (lhs, lhs_metadata) = lhs?;
343-
let (rhs, rhs_metadata) = rhs?;
344-
345-
let op = match op.as_rule() {
346-
Rule::add => BinaryOp::Add,
347-
Rule::sub => BinaryOp::Sub,
348-
Rule::mul => BinaryOp::Mul,
349-
Rule::div => BinaryOp::Div,
350-
Rule::pow => BinaryOp::Pow,
351-
Rule::paren => BinaryOp::Mul,
352-
Rule::leq => BinaryOp::Leq,
353-
Rule::geq => BinaryOp::Geq,
354-
Rule::eq => BinaryOp::Eq,
355-
rule => unreachable!("unexpected rule: {:?}", rule),
356-
};
357-
358-
let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit);
359-
360-
let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) {
361-
(true, true) => match op {
362-
BinaryOp::Mul => Unit {
363-
length: lhs_unit.length + rhs_unit.length,
364-
mass: lhs_unit.mass + rhs_unit.mass,
365-
time: lhs_unit.time + rhs_unit.time,
366-
},
367-
BinaryOp::Div => Unit {
368-
length: lhs_unit.length - rhs_unit.length,
369-
mass: lhs_unit.mass - rhs_unit.mass,
370-
time: lhs_unit.time - rhs_unit.time,
371-
},
372-
BinaryOp::Add | BinaryOp::Sub => {
373-
if lhs_unit == rhs_unit {
374-
lhs_unit
375-
} else {
376-
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit)));
377-
}
378-
}
379-
BinaryOp::Pow => {
380-
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit)));
381-
}
382-
BinaryOp::Eq | BinaryOp::Leq | BinaryOp::Geq => {
383-
if lhs_unit == rhs_unit {
384-
Unit::base_unit()
385-
} else {
386-
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit)));
387-
}
388-
}
389-
},
390-
391-
(true, false) => match op {
392-
BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))),
393-
BinaryOp::Pow => {
394-
//TODO: improve error type
395-
//TODO: support fractions
396-
if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) {
397-
if (val - val as i32 as f64).abs() <= f64::EPSILON {
398-
Unit {
399-
length: lhs_unit.length * val.round() as i32,
400-
mass: lhs_unit.mass * val.round() as i32,
401-
time: lhs_unit.time * val as i32,
402-
}
403-
} else {
404-
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT)));
405-
}
406-
} else {
407-
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT)));
408-
}
409-
}
410-
_ => lhs_unit,
411-
},
412-
(false, true) => match op {
413-
BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))),
414-
_ => rhs_unit,
415-
},
416-
(false, false) => Unit::BASE_UNIT,
417-
};
418-
419-
let node = Node::BinOp {
420-
lhs: Box::new(lhs),
421-
op,
422-
rhs: Box::new(rhs),
423-
};
424-
425-
Ok((node, NodeMetadata::new(unit)))
426-
})
427-
.parse(pairs)
428-
}*/
429-
430-
//TODO: set up Unit test for Units
431150
#[cfg(test)]
432151
mod tests {
433152
use super::*;

0 commit comments

Comments
 (0)