diff --git a/src/ast.rs b/src/ast.rs index 20dd2a70..f3976569 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -10,7 +10,7 @@ use simplicity::jet::Elements; use crate::debug::{CallTracker, DebugSymbols, TrackedCallName}; use crate::error::{Error, RichError, Span, WithSpan}; -use crate::num::{NonZeroPow2Usize, Pow2Usize}; +use crate::num::{NonZeroPow2Usize, Pow2Usize, U256}; use crate::parse::MatchPattern; use crate::pattern::Pattern; use crate::str::{AliasName, FunctionName, Identifier, ModuleName, WitnessName}; @@ -957,6 +957,77 @@ impl AbstractSyntaxTree for SingleExpression { let value = Value::parse_hexadecimal(bytes, ty).with_span(from)?; SingleExpressionInner::Constant(value) } + parse::SingleExpressionInner::TypeBound(type_bound) => match type_bound { + parse::TypeBound::UInt(uint_ty, bound) => { + let int_ty = ty + .as_integer() + .ok_or(Error::ExpressionUnexpectedType(ty.clone())) + .with_span(from)?; + + if int_ty != *uint_ty { + return Err(Error::ExpressionTypeMismatch( + ty.clone(), + ResolvedType::from(*uint_ty), + )) + .with_span(from); + } + + let value = match (uint_ty, bound) { + (UIntType::U1, parse::UIntBound::Min) => Value::from(UIntValue::U1(0)), + (UIntType::U1, parse::UIntBound::Max) => Value::from(UIntValue::U1(1)), + + (UIntType::U2, parse::UIntBound::Min) => Value::from(UIntValue::U2(0)), + (UIntType::U2, parse::UIntBound::Max) => Value::from(UIntValue::U2(3)), + + (UIntType::U4, parse::UIntBound::Min) => Value::from(UIntValue::U4(0)), + (UIntType::U4, parse::UIntBound::Max) => Value::from(UIntValue::U4(15)), + + (UIntType::U8, parse::UIntBound::Min) => { + Value::from(UIntValue::U8(u8::MIN)) + } + (UIntType::U8, parse::UIntBound::Max) => { + Value::from(UIntValue::U8(u8::MAX)) + } + + (UIntType::U16, parse::UIntBound::Min) => { + Value::from(UIntValue::U16(u16::MIN)) + } + (UIntType::U16, parse::UIntBound::Max) => { + Value::from(UIntValue::U16(u16::MAX)) + } + + (UIntType::U32, parse::UIntBound::Min) => { + Value::from(UIntValue::U32(u32::MIN)) + } + (UIntType::U32, parse::UIntBound::Max) => { + Value::from(UIntValue::U32(u32::MAX)) + } + + (UIntType::U64, parse::UIntBound::Min) => { + Value::from(UIntValue::U64(u64::MIN)) + } + (UIntType::U64, parse::UIntBound::Max) => { + Value::from(UIntValue::U64(u64::MAX)) + } + + (UIntType::U128, parse::UIntBound::Min) => { + Value::from(UIntValue::U128(u128::MIN)) + } + (UIntType::U128, parse::UIntBound::Max) => { + Value::from(UIntValue::U128(u128::MAX)) + } + + (UIntType::U256, parse::UIntBound::Min) => { + Value::from(UIntValue::U256(U256::MIN)) + } + (UIntType::U256, parse::UIntBound::Max) => { + Value::from(UIntValue::U256(U256::MAX)) + } + }; + + SingleExpressionInner::Constant(value) + } + }, parse::SingleExpressionInner::Witness(name) => { scope .insert_witness(name.clone(), ty.clone()) diff --git a/src/parse.rs b/src/parse.rs index a6eebd18..793eb183 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -249,6 +249,28 @@ impl TypeAlias { impl_eq_hash!(TypeAlias; name, ty); +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub enum UIntBound { + Min, + Max, +} + +impl UIntBound { + pub const fn as_str(self) -> &'static str { + match self { + Self::Min => "MIN", + Self::Max => "MAX", + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub enum TypeBound { + UInt(UIntType, UIntBound), +} + /// An expression is something that returns a value. #[derive(Clone, Debug)] pub struct Expression { @@ -339,6 +361,8 @@ pub enum SingleExpressionInner { Binary(Binary), /// Hexadecimal string literal. Hexadecimal(Hexadecimal), + /// Constants of a type (e.g. MAX, MIN) + TypeBound(TypeBound), /// Witness value. Witness(WitnessName), /// Parameter value. @@ -633,9 +657,8 @@ impl TreeLike for ExprTree<'_> { | S::Decimal(_) | S::Hexadecimal(_) | S::Variable(_) - | S::Witness(_) - | S::Parameter(_) - | S::Option(None) => Tree::Nullary, + | S::TypeBound(_) => Tree::Nullary, + S::Witness(_) | S::Parameter(_) | S::Option(None) => Tree::Nullary, S::Option(Some(l)) | S::Either(Either::Left(l)) | S::Either(Either::Right(l)) @@ -684,6 +707,9 @@ impl fmt::Display for ExprTree<'_> { S::Decimal(decimal) => write!(f, "{decimal}")?, S::Hexadecimal(hexadecimal) => write!(f, "0x{hexadecimal}")?, S::Variable(name) => write!(f, "{name}")?, + S::TypeBound(TypeBound::UInt(ty, bound)) => { + write!(f, "{ty}::{}", bound.as_str())? + } S::Witness(name) => write!(f, "witness::{name}")?, S::Parameter(name) => write!(f, "param::{name}")?, S::Option(None) => write!(f, "None")?, @@ -1606,6 +1632,34 @@ impl SingleExpression { let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match); + let type_bound = Identifier::parser() + .then_ignore(just(Token::DoubleColon)) + .then(Identifier::parser()) + .try_map( + |(lhs, rhs), span| match UIntType::from_str(lhs.as_inner()) { + Ok(ty) => Ok((ty, rhs)), + // this is a fall through, this error is not emitted + Err(_) => Err(Error::Grammar("not a type bound".into()).with_span(span)), + }, + ) + .validate(|(ty, rhs), e, emit| { + let bound = match rhs.as_inner() { + "MIN" => UIntBound::Min, + "MAX" => UIntBound::Max, + _ => { + // Send the error through here + emit.emit( + Error::Grammar( + "Expected `MIN` or `MAX` after unsigned integer type".into(), + ) + .with_span(e.span()), + ); + UIntBound::Min + } + }; + SingleExpressionInner::TypeBound(TypeBound::UInt(ty, bound)) + }); + let variable = Identifier::parser().map(SingleExpressionInner::Variable); // Expression delimeted by parentheses @@ -1616,7 +1670,7 @@ impl SingleExpression { choice(( left, right, some, none, boolean, match_expr, expression, list, array, tuple, call, - literal, variable, + literal, type_bound, variable, )) .map_with(|inner, e| Self { inner, @@ -2193,7 +2247,10 @@ mod test { let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); assert!(parse_program.is_none()); - assert!(ErrorCollector::to_string(&error_handler).contains("Expected '::', found ':'")); + let errors = ErrorCollector::to_string(&error_handler); + + assert!(parse_program.is_none()); + assert!(errors.contains("::"), "{errors}"); } #[test] @@ -2203,6 +2260,119 @@ mod test { let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); assert!(parse_program.is_none()); - assert!(ErrorCollector::to_string(&error_handler).contains("Expected ';', found '::'")); + let errors = ErrorCollector::to_string(&error_handler); + + assert!(parse_program.is_none()); + assert!(errors.contains("::"), "{errors}"); + } + + #[test] + fn invalid_input_falls_through_type_bound_on_try_match() { + let input = "fn main() { let pk: Pubkey = witnes::PK; }"; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); + + assert!(parse_program.is_none()); + + let error_str = ErrorCollector::to_string(&error_handler); + + assert!( + error_str.contains("Expected ';', found '::'"), + "{}", + error_str + ); + + assert!( + !error_str.contains("Expected unsigned integer type before `::`"), + "{}", + error_str + ); + + assert!( + !error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"), + "{}", + error_str + ); + } + + #[test] + fn valid_input_falls_through_type_bound_when_not_match() { + let input = "fn main() { let pk: Pubkey = Witness::PK; }"; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); + + assert!(parse_program.is_none()); + + let error_str = ErrorCollector::to_string(&error_handler); + + assert!( + error_str.contains("Expected ';', found '::'"), + "{}", + error_str + ); + + assert!( + !error_str.contains("Expected unsigned integer type before `::`"), + "{}", + error_str + ); + + assert!( + !error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"), + "{}", + error_str + ); + } + + #[test] + fn parses_u8_min_max_invalid_reports_error() { + let input = "fn main() { let a: u8 = u8::MI; let b: u8 = u8::MA; }"; + let mut error_handler = ErrorCollector::new(Arc::from(input)); + let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler); + let error_str = ErrorCollector::to_string(&error_handler); + + assert!(parse_program.is_none()); + + assert!( + error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"), + "{}", + error_str + ); + } + + #[test] + fn parses_u8_min_max() { + let src = "fn main() { let a: u8 = u8::MIN; let b: u8 = u8::MAX; }"; + let program = Program::parse_from_str(src).expect("should parse"); + let rendered = program.to_string(); + assert!(rendered.contains("u8::MIN")); + assert!(rendered.contains("u8::MAX")); + } + + #[test] + fn parses_all_uint_min_max() { + let types = ["u1", "u2", "u4", "u8", "u16", "u32", "u64", "u128", "u256"]; + + for ty in types { + let src = format!( + "fn main() {{ let a: {t} = {t}::MIN; let b: {t} = {t}::MAX; }}", + t = ty + ); + + let program = Program::parse_from_str(&src) + .unwrap_or_else(|e| panic!("failed to parse for {ty}: {e:?}")); + + let rendered = program.to_string(); + + assert!( + rendered.contains(&format!("{ty}::MIN")), + "missing MIN for {ty}" + ); + + assert!( + rendered.contains(&format!("{ty}::MAX")), + "missing MAX for {ty}" + ); + } } }