Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use simplicity::jet::Elements;

use crate::debug::{CallTracker, DebugSymbols, TrackedCallName};
use crate::error::{Error, RichError, Span, WithSpan};
use crate::num::{NonZeroPow2Usize, Pow2Usize};
use crate::num::{NonZeroPow2Usize, Pow2Usize, U256};
use crate::parse::MatchPattern;
use crate::pattern::Pattern;
use crate::str::{AliasName, FunctionName, Identifier, ModuleName, WitnessName};
Expand Down Expand Up @@ -957,6 +957,77 @@ impl AbstractSyntaxTree for SingleExpression {
let value = Value::parse_hexadecimal(bytes, ty).with_span(from)?;
SingleExpressionInner::Constant(value)
}
parse::SingleExpressionInner::TypeBound(type_bound) => match type_bound {
parse::TypeBound::UInt(uint_ty, bound) => {
let int_ty = ty
.as_integer()
.ok_or(Error::ExpressionUnexpectedType(ty.clone()))
.with_span(from)?;

if int_ty != *uint_ty {
return Err(Error::ExpressionTypeMismatch(
ty.clone(),
ResolvedType::from(*uint_ty),
))
.with_span(from);
}

let value = match (uint_ty, bound) {
(UIntType::U1, parse::UIntBound::Min) => Value::from(UIntValue::U1(0)),
(UIntType::U1, parse::UIntBound::Max) => Value::from(UIntValue::U1(1)),

(UIntType::U2, parse::UIntBound::Min) => Value::from(UIntValue::U2(0)),
(UIntType::U2, parse::UIntBound::Max) => Value::from(UIntValue::U2(3)),

(UIntType::U4, parse::UIntBound::Min) => Value::from(UIntValue::U4(0)),
(UIntType::U4, parse::UIntBound::Max) => Value::from(UIntValue::U4(15)),

(UIntType::U8, parse::UIntBound::Min) => {
Value::from(UIntValue::U8(u8::MIN))
}
(UIntType::U8, parse::UIntBound::Max) => {
Value::from(UIntValue::U8(u8::MAX))
}

(UIntType::U16, parse::UIntBound::Min) => {
Value::from(UIntValue::U16(u16::MIN))
}
(UIntType::U16, parse::UIntBound::Max) => {
Value::from(UIntValue::U16(u16::MAX))
}

(UIntType::U32, parse::UIntBound::Min) => {
Value::from(UIntValue::U32(u32::MIN))
}
(UIntType::U32, parse::UIntBound::Max) => {
Value::from(UIntValue::U32(u32::MAX))
}

(UIntType::U64, parse::UIntBound::Min) => {
Value::from(UIntValue::U64(u64::MIN))
}
(UIntType::U64, parse::UIntBound::Max) => {
Value::from(UIntValue::U64(u64::MAX))
}

(UIntType::U128, parse::UIntBound::Min) => {
Value::from(UIntValue::U128(u128::MIN))
}
(UIntType::U128, parse::UIntBound::Max) => {
Value::from(UIntValue::U128(u128::MAX))
}

(UIntType::U256, parse::UIntBound::Min) => {
Value::from(UIntValue::U256(U256::MIN))
}
(UIntType::U256, parse::UIntBound::Max) => {
Value::from(UIntValue::U256(U256::MAX))
}
};

SingleExpressionInner::Constant(value)
}
},
parse::SingleExpressionInner::Witness(name) => {
scope
.insert_witness(name.clone(), ty.clone())
Expand Down
182 changes: 176 additions & 6 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,28 @@ impl TypeAlias {

impl_eq_hash!(TypeAlias; name, ty);

#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum UIntBound {
Min,
Max,
}

impl UIntBound {
pub const fn as_str(self) -> &'static str {
match self {
Self::Min => "MIN",
Self::Max => "MAX",
}
}
}

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum TypeBound {
UInt(UIntType, UIntBound),
}

/// An expression is something that returns a value.
#[derive(Clone, Debug)]
pub struct Expression {
Expand Down Expand Up @@ -339,6 +361,8 @@ pub enum SingleExpressionInner {
Binary(Binary),
/// Hexadecimal string literal.
Hexadecimal(Hexadecimal),
/// Constants of a type (e.g. MAX, MIN)
TypeBound(TypeBound),
/// Witness value.
Witness(WitnessName),
/// Parameter value.
Expand Down Expand Up @@ -633,9 +657,8 @@ impl TreeLike for ExprTree<'_> {
| S::Decimal(_)
| S::Hexadecimal(_)
| S::Variable(_)
| S::Witness(_)
| S::Parameter(_)
| S::Option(None) => Tree::Nullary,
| S::TypeBound(_) => Tree::Nullary,
S::Witness(_) | S::Parameter(_) | S::Option(None) => Tree::Nullary,
S::Option(Some(l))
| S::Either(Either::Left(l))
| S::Either(Either::Right(l))
Expand Down Expand Up @@ -684,6 +707,9 @@ impl fmt::Display for ExprTree<'_> {
S::Decimal(decimal) => write!(f, "{decimal}")?,
S::Hexadecimal(hexadecimal) => write!(f, "0x{hexadecimal}")?,
S::Variable(name) => write!(f, "{name}")?,
S::TypeBound(TypeBound::UInt(ty, bound)) => {
write!(f, "{ty}::{}", bound.as_str())?
}
S::Witness(name) => write!(f, "witness::{name}")?,
S::Parameter(name) => write!(f, "param::{name}")?,
S::Option(None) => write!(f, "None")?,
Expand Down Expand Up @@ -1606,6 +1632,34 @@ impl SingleExpression {

let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match);

let type_bound = Identifier::parser()
.then_ignore(just(Token::DoubleColon))
.then(Identifier::parser())
.try_map(
|(lhs, rhs), span| match UIntType::from_str(lhs.as_inner()) {
Ok(ty) => Ok((ty, rhs)),
// this is a fall through, this error is not emitted
Err(_) => Err(Error::Grammar("not a type bound".into()).with_span(span)),
},
)
.validate(|(ty, rhs), e, emit| {
let bound = match rhs.as_inner() {
"MIN" => UIntBound::Min,
"MAX" => UIntBound::Max,
_ => {
// Send the error through here
emit.emit(
Error::Grammar(
"Expected `MIN` or `MAX` after unsigned integer type".into(),
)
.with_span(e.span()),
);
UIntBound::Min
}
};
SingleExpressionInner::TypeBound(TypeBound::UInt(ty, bound))
});

let variable = Identifier::parser().map(SingleExpressionInner::Variable);

// Expression delimeted by parentheses
Expand All @@ -1616,7 +1670,7 @@ impl SingleExpression {

choice((
left, right, some, none, boolean, match_expr, expression, list, array, tuple, call,
literal, variable,
literal, type_bound, variable,
))
.map_with(|inner, e| Self {
inner,
Expand Down Expand Up @@ -2193,7 +2247,10 @@ mod test {
let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler);

assert!(parse_program.is_none());
assert!(ErrorCollector::to_string(&error_handler).contains("Expected '::', found ':'"));
let errors = ErrorCollector::to_string(&error_handler);

assert!(parse_program.is_none());
assert!(errors.contains("::"), "{errors}");
}

#[test]
Expand All @@ -2203,6 +2260,119 @@ mod test {
let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler);

assert!(parse_program.is_none());
assert!(ErrorCollector::to_string(&error_handler).contains("Expected ';', found '::'"));
let errors = ErrorCollector::to_string(&error_handler);

assert!(parse_program.is_none());
assert!(errors.contains("::"), "{errors}");
}

#[test]
fn invalid_input_falls_through_type_bound_on_try_match() {
let input = "fn main() { let pk: Pubkey = witnes::PK; }";
let mut error_handler = ErrorCollector::new(Arc::from(input));
let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler);

assert!(parse_program.is_none());

let error_str = ErrorCollector::to_string(&error_handler);

assert!(
error_str.contains("Expected ';', found '::'"),
"{}",
error_str
);

assert!(
!error_str.contains("Expected unsigned integer type before `::`"),
"{}",
error_str
);

assert!(
!error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"),
"{}",
error_str
);
}

#[test]
fn valid_input_falls_through_type_bound_when_not_match() {
let input = "fn main() { let pk: Pubkey = Witness::PK; }";
let mut error_handler = ErrorCollector::new(Arc::from(input));
let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler);

assert!(parse_program.is_none());

let error_str = ErrorCollector::to_string(&error_handler);

assert!(
error_str.contains("Expected ';', found '::'"),
"{}",
error_str
);

assert!(
!error_str.contains("Expected unsigned integer type before `::`"),
"{}",
error_str
);

assert!(
!error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"),
"{}",
error_str
);
}

#[test]
fn parses_u8_min_max_invalid_reports_error() {
let input = "fn main() { let a: u8 = u8::MI; let b: u8 = u8::MA; }";
let mut error_handler = ErrorCollector::new(Arc::from(input));
let parse_program = Program::parse_from_str_with_errors(input, &mut error_handler);
let error_str = ErrorCollector::to_string(&error_handler);

assert!(parse_program.is_none());

assert!(
error_str.contains("Expected `MIN` or `MAX` after unsigned integer type"),
"{}",
error_str
);
}

#[test]
fn parses_u8_min_max() {
let src = "fn main() { let a: u8 = u8::MIN; let b: u8 = u8::MAX; }";
let program = Program::parse_from_str(src).expect("should parse");
let rendered = program.to_string();
assert!(rendered.contains("u8::MIN"));
assert!(rendered.contains("u8::MAX"));
}

#[test]
fn parses_all_uint_min_max() {
let types = ["u1", "u2", "u4", "u8", "u16", "u32", "u64", "u128", "u256"];

for ty in types {
let src = format!(
"fn main() {{ let a: {t} = {t}::MIN; let b: {t} = {t}::MAX; }}",
t = ty
);

let program = Program::parse_from_str(&src)
.unwrap_or_else(|e| panic!("failed to parse for {ty}: {e:?}"));

let rendered = program.to_string();

assert!(
rendered.contains(&format!("{ty}::MIN")),
"missing MIN for {ty}"
);

assert!(
rendered.contains(&format!("{ty}::MAX")),
"missing MAX for {ty}"
);
}
}
}
Loading