Skip to content

Commit d27b895

Browse files
committed
feat: change ast.rs to match the functionality of the driver
1 parent 3a84d12 commit d27b895

6 files changed

Lines changed: 242 additions & 97 deletions

File tree

src/ast.rs

Lines changed: 111 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::collections::hash_map::Entry;
2-
use std::collections::HashMap;
2+
use std::collections::{BTreeSet, HashMap};
33
use std::num::NonZeroUsize;
44
use std::str::FromStr;
55
use std::sync::Arc;
@@ -9,6 +9,7 @@ use miniscript::iter::{Tree, TreeLike};
99
use simplicity::jet::Elements;
1010

1111
use crate::debug::{CallTracker, DebugSymbols, TrackedCallName};
12+
use crate::driver::resolve_order::{AliasRegistry, ItemNameWithFileId};
1213
use crate::error::{Error, RichError, Span, WithSpan};
1314
use crate::num::{NonZeroPow2Usize, Pow2Usize};
1415
use crate::parse::MatchPattern;
@@ -19,7 +20,7 @@ use crate::types::{
1920
};
2021
use crate::value::{UIntValue, Value};
2122
use crate::witness::{Parameters, WitnessTypes, WitnessValues};
22-
use crate::{impl_eq_hash, parse};
23+
use crate::{driver, impl_eq_hash, parse};
2324

2425
/// A program consists of the main function.
2526
///
@@ -520,18 +521,47 @@ impl TreeLike for ExprTree<'_> {
520521
/// 2. Resolving type aliases
521522
/// 3. Assigning types to each witness expression
522523
/// 4. Resolving calls to custom functions
523-
#[derive(Clone, Debug, Eq, PartialEq, Default)]
524+
#[derive(Clone, Debug, Eq, PartialEq)]
524525
struct Scope {
526+
resolutions: Arc<[BTreeSet<Arc<str>>]>,
527+
import_aliases: AliasRegistry,
528+
file_id: usize,
529+
525530
variables: Vec<HashMap<Identifier, ResolvedType>>,
526531
aliases: HashMap<AliasName, ResolvedType>,
527532
parameters: HashMap<WitnessName, ResolvedType>,
528533
witnesses: HashMap<WitnessName, ResolvedType>,
529-
functions: HashMap<FunctionName, CustomFunction>,
534+
functions: HashMap<ItemNameWithFileId, CustomFunction>,
530535
is_main: bool,
531536
call_tracker: CallTracker,
532537
}
533538

539+
impl Default for Scope {
540+
fn default() -> Self {
541+
Self::new(Arc::from([]), AliasRegistry::default())
542+
}
543+
}
544+
534545
impl Scope {
546+
pub fn new(resolutions: Arc<[BTreeSet<Arc<str>>]>, import_aliases: AliasRegistry) -> Self {
547+
Self {
548+
resolutions,
549+
import_aliases,
550+
file_id: 0,
551+
variables: Vec::new(),
552+
aliases: HashMap::new(),
553+
parameters: HashMap::new(),
554+
witnesses: HashMap::new(),
555+
functions: HashMap::new(),
556+
is_main: false,
557+
call_tracker: CallTracker::default(),
558+
}
559+
}
560+
561+
pub fn file_id(&self) -> usize {
562+
self.file_id
563+
}
564+
535565
/// Check if the current scope is topmost.
536566
pub fn is_topmost(&self) -> bool {
537567
self.variables.is_empty()
@@ -542,6 +572,11 @@ impl Scope {
542572
self.variables.push(HashMap::new());
543573
}
544574

575+
pub fn push_function_scope(&mut self, file_id: usize) {
576+
self.push_scope();
577+
self.file_id = file_id;
578+
}
579+
545580
/// Push the scope of the main function onto the stack.
546581
///
547582
/// ## Panics
@@ -564,6 +599,11 @@ impl Scope {
564599
self.variables.pop().expect("Stack is empty");
565600
}
566601

602+
pub fn pop_function_scope(&mut self, previous_file_id: usize) {
603+
self.pop_scope();
604+
self.file_id = previous_file_id;
605+
}
606+
567607
/// Pop the scope of the main function from the stack.
568608
///
569609
/// ## Panics
@@ -682,20 +722,66 @@ impl Scope {
682722
pub fn insert_function(
683723
&mut self,
684724
name: FunctionName,
725+
file_id: usize,
685726
function: CustomFunction,
686727
) -> Result<(), Error> {
687-
match self.functions.entry(name.clone()) {
688-
Entry::Occupied(_) => Err(Error::FunctionRedefined(name)),
689-
Entry::Vacant(entry) => {
690-
entry.insert(function);
691-
Ok(())
692-
}
728+
let global_id = (Arc::from(name.as_inner()), file_id);
729+
730+
if self.functions.contains_key(&global_id) {
731+
return Err(Error::FunctionRedefined(name));
693732
}
733+
734+
let _ = self.functions.insert(global_id, function);
735+
Ok(())
694736
}
695737

696-
/// Get the definition of a custom function.
697-
pub fn get_function(&self, name: &FunctionName) -> Option<&CustomFunction> {
698-
self.functions.get(name)
738+
// NOTE: Why do we use this function to retrieve a `TypeAlias`?
739+
740+
/// Retrieves the definition of a custom function, enforcing strict error prioritization.
741+
///
742+
/// # Architecture Note
743+
/// The order of operations here is intentional to prioritize specific compiler errors:
744+
/// 1. Resolve the alias to find the true global coordinates.
745+
/// 2. Check for global existence (`FunctionUndefined`) *before* checking local visibility.
746+
/// 3. Verify if the current file's scope is actually allowed to see it (`PrivateItem`).
747+
///
748+
/// # Errors
749+
///
750+
/// * [`Error::FunctionUndefined`]: The function is not found in the global registry.
751+
/// * [`Error::Internal`]: The specified `file_id` does not exist in the `files`.
752+
/// * [`Error::PrivateItem`]: The function exists globally but is not exposed to the current file's scope.
753+
pub fn get_function(&self, name: &FunctionName) -> Result<&CustomFunction, Error> {
754+
// 1. Get the true global ID of the alias (or keep the current name if it is not aliased).
755+
let initial_id = (Arc::from(name.as_inner()), self.file_id);
756+
let global_id = self
757+
.import_aliases
758+
.resolved_roots()
759+
.get(&initial_id)
760+
.cloned()
761+
.unwrap_or(initial_id);
762+
763+
// 2. Fetch the function from the global pool.
764+
// We do this first so we can throw FunctionUndefined before checking local visibility.
765+
let function = self
766+
.functions
767+
.get(&global_id)
768+
.ok_or_else(|| Error::FunctionUndefined(name.clone()))?;
769+
770+
// TODO: Consider changing it to a better error handler with a source file.
771+
let file_scope = self.resolutions.get(self.file_id).ok_or_else(|| {
772+
Error::Internal(format!(
773+
"file_id {} not found inside current Scope files",
774+
self.file_id
775+
))
776+
})?;
777+
778+
// 3. Verify local scope visibility.
779+
// We successfully found the function globally, but is this file allowed to use it?
780+
if file_scope.contains(&Arc::from(name.as_inner())) {
781+
Ok(function)
782+
} else {
783+
Err(Error::PrivateItem(name.as_inner().to_string()))
784+
}
699785
}
700786

701787
/// Track a call expression with its span.
@@ -718,9 +804,10 @@ trait AbstractSyntaxTree: Sized {
718804
}
719805

720806
impl Program {
721-
pub fn analyze(from: &parse::Program) -> Result<Self, RichError> {
807+
pub fn analyze(from: &driver::resolve_order::Program) -> Result<Self, RichError> {
722808
let unit = ResolvedType::unit();
723-
let mut scope = Scope::default();
809+
let mut scope = Scope::new(Arc::from(from.resolutions()), from.import_aliases().clone());
810+
724811
let items = from
725812
.items()
726813
.iter()
@@ -732,6 +819,7 @@ impl Program {
732819
Item::Function(Function::Main(expr)) => Some(expr),
733820
_ => None,
734821
});
822+
735823
let main = iter.next().ok_or(Error::MainRequired).with_span(from)?;
736824
if iter.next().is_some() {
737825
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from);
@@ -777,8 +865,10 @@ impl AbstractSyntaxTree for Function {
777865
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
778866
assert!(ty.is_unit(), "Function definitions cannot return anything");
779867
assert!(scope.is_topmost(), "Items live in the topmost scope only");
868+
let previous_file_id = scope.file_id();
780869

781870
if from.name().as_inner() != "main" {
871+
let file_id = from.file_id();
782872
let params = from
783873
.params()
784874
.iter()
@@ -795,16 +885,16 @@ impl AbstractSyntaxTree for Function {
795885
.map(|aliased| scope.resolve(aliased).with_span(from))
796886
.transpose()?
797887
.unwrap_or_else(ResolvedType::unit);
798-
scope.push_scope();
888+
scope.push_function_scope(file_id);
799889
for param in params.iter() {
800890
scope.insert_variable(param.identifier().clone(), param.ty().clone());
801891
}
802892
let body = Expression::analyze(from.body(), &ret, scope).map(Arc::new)?;
803-
scope.pop_scope();
893+
scope.pop_function_scope(previous_file_id);
804894
debug_assert!(scope.is_topmost());
805895
let function = CustomFunction { params, body };
806896
scope
807-
.insert_function(from.name().clone(), function)
897+
.insert_function(from.name().clone(), file_id, function)
808898
.with_span(from)?;
809899

810900
return Ok(Self::Custom);
@@ -1325,14 +1415,9 @@ impl AbstractSyntaxTree for CallName {
13251415
.get_function(name)
13261416
.cloned()
13271417
.map(Self::Custom)
1328-
.ok_or(Error::FunctionUndefined(name.clone()))
13291418
.with_span(from),
13301419
parse::CallName::ArrayFold(name, size) => {
1331-
let function = scope
1332-
.get_function(name)
1333-
.cloned()
1334-
.ok_or(Error::FunctionUndefined(name.clone()))
1335-
.with_span(from)?;
1420+
let function = scope.get_function(name).cloned().with_span(from)?;
13361421
// A function that is used in a array fold has the signature:
13371422
// fn f(element: E, accumulator: A) -> A
13381423
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
@@ -1343,11 +1428,7 @@ impl AbstractSyntaxTree for CallName {
13431428
}
13441429
}
13451430
parse::CallName::Fold(name, bound) => {
1346-
let function = scope
1347-
.get_function(name)
1348-
.cloned()
1349-
.ok_or(Error::FunctionUndefined(name.clone()))
1350-
.with_span(from)?;
1431+
let function = scope.get_function(name).cloned().with_span(from)?;
13511432
// A function that is used in a list fold has the signature:
13521433
// fn f(element: E, accumulator: A) -> A
13531434
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
@@ -1358,11 +1439,7 @@ impl AbstractSyntaxTree for CallName {
13581439
}
13591440
}
13601441
parse::CallName::ForWhile(name) => {
1361-
let function = scope
1362-
.get_function(name)
1363-
.cloned()
1364-
.ok_or(Error::FunctionUndefined(name.clone()))
1365-
.with_span(from)?;
1442+
let function = scope.get_function(name).cloned().with_span(from)?;
13661443
// A function that is used in a for-while loop has the signature:
13671444
// fn f(accumulator: A, readonly_context: C, counter: u{N}) -> Either<B, A>
13681445
// where

src/driver/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
//! the dependency graph construction.
2929
3030
mod linearization;
31-
mod resolve_order;
31+
pub mod resolve_order;
3232

3333
use std::collections::{HashMap, HashSet, VecDeque};
3434
use std::path::PathBuf;

0 commit comments

Comments
 (0)