Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 111 additions & 34 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -9,6 +9,7 @@ use miniscript::iter::{Tree, TreeLike};
use simplicity::jet::Elements;

use crate::debug::{CallTracker, DebugSymbols, TrackedCallName};
use crate::driver::resolve_order::{AliasRegistry, ItemNameWithFileId};
use crate::error::{Error, RichError, Span, WithSpan};
use crate::num::{NonZeroPow2Usize, Pow2Usize};
use crate::parse::MatchPattern;
Expand All @@ -19,7 +20,7 @@ use crate::types::{
};
use crate::value::{UIntValue, Value};
use crate::witness::{Parameters, WitnessTypes, WitnessValues};
use crate::{impl_eq_hash, parse};
use crate::{driver, impl_eq_hash, parse};

/// A program consists of the main function.
///
Expand Down Expand Up @@ -520,18 +521,47 @@ impl TreeLike for ExprTree<'_> {
/// 2. Resolving type aliases
/// 3. Assigning types to each witness expression
/// 4. Resolving calls to custom functions
#[derive(Clone, Debug, Eq, PartialEq, Default)]
#[derive(Clone, Debug, Eq, PartialEq)]
struct Scope {
resolutions: Arc<[BTreeSet<Arc<str>>]>,
import_aliases: AliasRegistry,
file_id: usize,

variables: Vec<HashMap<Identifier, ResolvedType>>,
aliases: HashMap<AliasName, ResolvedType>,
parameters: HashMap<WitnessName, ResolvedType>,
witnesses: HashMap<WitnessName, ResolvedType>,
functions: HashMap<FunctionName, CustomFunction>,
functions: HashMap<ItemNameWithFileId, CustomFunction>,
is_main: bool,
call_tracker: CallTracker,
}

impl Default for Scope {
fn default() -> Self {
Self::new(Arc::from([]), AliasRegistry::default())
}
}

impl Scope {
pub fn new(resolutions: Arc<[BTreeSet<Arc<str>>]>, import_aliases: AliasRegistry) -> Self {
Self {
resolutions,
import_aliases,
file_id: 0,
variables: Vec::new(),
aliases: HashMap::new(),
parameters: HashMap::new(),
witnesses: HashMap::new(),
functions: HashMap::new(),
is_main: false,
call_tracker: CallTracker::default(),
}
}

pub fn file_id(&self) -> usize {
self.file_id
}

/// Check if the current scope is topmost.
pub fn is_topmost(&self) -> bool {
self.variables.is_empty()
Expand All @@ -542,6 +572,11 @@ impl Scope {
self.variables.push(HashMap::new());
}

pub fn push_function_scope(&mut self, file_id: usize) {
self.push_scope();
self.file_id = file_id;
}

/// Push the scope of the main function onto the stack.
///
/// ## Panics
Expand All @@ -564,6 +599,11 @@ impl Scope {
self.variables.pop().expect("Stack is empty");
}

pub fn pop_function_scope(&mut self, previous_file_id: usize) {
self.pop_scope();
self.file_id = previous_file_id;
}

/// Pop the scope of the main function from the stack.
///
/// ## Panics
Expand Down Expand Up @@ -682,20 +722,66 @@ impl Scope {
pub fn insert_function(
&mut self,
name: FunctionName,
file_id: usize,
function: CustomFunction,
) -> Result<(), Error> {
match self.functions.entry(name.clone()) {
Entry::Occupied(_) => Err(Error::FunctionRedefined(name)),
Entry::Vacant(entry) => {
entry.insert(function);
Ok(())
}
let global_id = (Arc::from(name.as_inner()), file_id);

if self.functions.contains_key(&global_id) {
return Err(Error::FunctionRedefined(name));
}

let _ = self.functions.insert(global_id, function);
Ok(())
}

/// Get the definition of a custom function.
pub fn get_function(&self, name: &FunctionName) -> Option<&CustomFunction> {
self.functions.get(name)
// NOTE: Why do we use this function to retrieve a `TypeAlias`?

/// Retrieves the definition of a custom function, enforcing strict error prioritization.
///
/// # Architecture Note
/// The order of operations here is intentional to prioritize specific compiler errors:
/// 1. Resolve the alias to find the true global coordinates.
/// 2. Check for global existence (`FunctionUndefined`) *before* checking local visibility.
/// 3. Verify if the current file's scope is actually allowed to see it (`PrivateItem`).
///
/// # Errors
///
/// * [`Error::FunctionUndefined`]: The function is not found in the global registry.
/// * [`Error::Internal`]: The specified `file_id` does not exist in the `files`.
/// * [`Error::PrivateItem`]: The function exists globally but is not exposed to the current file's scope.
pub fn get_function(&self, name: &FunctionName) -> Result<&CustomFunction, Error> {
// 1. Get the true global ID of the alias (or keep the current name if it is not aliased).
let initial_id = (Arc::from(name.as_inner()), self.file_id);
let global_id = self
.import_aliases
.resolved_roots()
.get(&initial_id)
.cloned()
.unwrap_or(initial_id);

// 2. Fetch the function from the global pool.
// We do this first so we can throw FunctionUndefined before checking local visibility.
let function = self
.functions
.get(&global_id)
.ok_or_else(|| Error::FunctionUndefined(name.clone()))?;

// TODO: Consider changing it to a better error handler with a source file.
let file_scope = self.resolutions.get(self.file_id).ok_or_else(|| {
Error::Internal(format!(
"file_id {} not found inside current Scope files",
self.file_id
))
})?;

// 3. Verify local scope visibility.
// We successfully found the function globally, but is this file allowed to use it?
if file_scope.contains(&Arc::from(name.as_inner())) {
Ok(function)
} else {
Err(Error::PrivateItem(name.as_inner().to_string()))
}
}

/// Track a call expression with its span.
Expand All @@ -718,9 +804,10 @@ trait AbstractSyntaxTree: Sized {
}

impl Program {
pub fn analyze(from: &parse::Program) -> Result<Self, RichError> {
pub fn analyze(from: &driver::resolve_order::Program) -> Result<Self, RichError> {
let unit = ResolvedType::unit();
let mut scope = Scope::default();
let mut scope = Scope::new(Arc::from(from.resolutions()), from.import_aliases().clone());

let items = from
.items()
.iter()
Expand All @@ -732,6 +819,7 @@ impl Program {
Item::Function(Function::Main(expr)) => Some(expr),
_ => None,
});

let main = iter.next().ok_or(Error::MainRequired).with_span(from)?;
if iter.next().is_some() {
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from);
Expand Down Expand Up @@ -777,8 +865,10 @@ impl AbstractSyntaxTree for Function {
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
assert!(ty.is_unit(), "Function definitions cannot return anything");
assert!(scope.is_topmost(), "Items live in the topmost scope only");
let previous_file_id = scope.file_id();

if from.name().as_inner() != "main" {
let file_id = from.file_id();
let params = from
.params()
.iter()
Expand All @@ -795,16 +885,16 @@ impl AbstractSyntaxTree for Function {
.map(|aliased| scope.resolve(aliased).with_span(from))
.transpose()?
.unwrap_or_else(ResolvedType::unit);
scope.push_scope();
scope.push_function_scope(file_id);
for param in params.iter() {
scope.insert_variable(param.identifier().clone(), param.ty().clone());
}
let body = Expression::analyze(from.body(), &ret, scope).map(Arc::new)?;
scope.pop_scope();
scope.pop_function_scope(previous_file_id);
debug_assert!(scope.is_topmost());
let function = CustomFunction { params, body };
scope
.insert_function(from.name().clone(), function)
.insert_function(from.name().clone(), file_id, function)
.with_span(from)?;

return Ok(Self::Custom);
Expand Down Expand Up @@ -1325,14 +1415,9 @@ impl AbstractSyntaxTree for CallName {
.get_function(name)
.cloned()
.map(Self::Custom)
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from),
parse::CallName::ArrayFold(name, size) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
let function = scope.get_function(name).cloned().with_span(from)?;
// A function that is used in a array fold has the signature:
// fn f(element: E, accumulator: A) -> A
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
Expand All @@ -1343,11 +1428,7 @@ impl AbstractSyntaxTree for CallName {
}
}
parse::CallName::Fold(name, bound) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
let function = scope.get_function(name).cloned().with_span(from)?;
// A function that is used in a list fold has the signature:
// fn f(element: E, accumulator: A) -> A
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
Expand All @@ -1358,11 +1439,7 @@ impl AbstractSyntaxTree for CallName {
}
}
parse::CallName::ForWhile(name) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
let function = scope.get_function(name).cloned().with_span(from)?;
// A function that is used in a for-while loop has the signature:
// fn f(accumulator: A, readonly_context: C, counter: u{N}) -> Either<B, A>
// where
Expand Down
2 changes: 1 addition & 1 deletion src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
//! the dependency graph construction.

mod linearization;
mod resolve_order;
pub(crate) mod resolve_order;

use std::collections::{HashMap, HashSet, VecDeque};
use std::path::PathBuf;
Expand Down
Loading
Loading