Skip to content
This repository was archived by the owner on Apr 2, 2026. It is now read-only.

Commit a54c493

Browse files
author
Joshua Barretto
committed
Added type solver to mini_ml
1 parent b80f34b commit a54c493

1 file changed

Lines changed: 186 additions & 32 deletions

File tree

examples/mini_ml.rs

Lines changed: 186 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ pub enum Token<'src> {
1515
Let,
1616
In,
1717
Fn,
18+
True,
19+
False,
1820
}
1921

2022
pub type Spanned<T> = (T, SimpleSpan);
@@ -26,6 +28,8 @@ fn lexer<'src>(
2628
"let" => Token::Let,
2729
"in" => Token::In,
2830
"fn" => Token::Fn,
31+
"true" => Token::True,
32+
"false" => Token::False,
2933
s => Token::Ident(s),
3034
});
3135

@@ -57,15 +61,26 @@ fn lexer<'src>(
5761
.collect()
5862
}
5963

60-
#[derive(Debug)]
64+
#[derive(Clone, Debug)]
6165
pub enum Expr<'src> {
62-
Local(&'src str),
66+
Var(&'src str),
6367
Num(f64),
64-
Let(Spanned<&'src str>, Box<Spanned<Self>>, Box<Spanned<Self>>),
68+
Bool(bool),
6569
Add(Box<Spanned<Self>>, Box<Spanned<Self>>),
6670
Mul(Box<Spanned<Self>>, Box<Spanned<Self>>),
67-
Call(Box<Spanned<Self>>, Box<Spanned<Self>>),
68-
Func(Vec<Spanned<&'src str>>, Box<Spanned<Self>>),
71+
Let {
72+
lhs: Spanned<&'src str>,
73+
rhs: Box<Spanned<Self>>,
74+
then: Box<Spanned<Self>>,
75+
},
76+
Apply {
77+
func: Box<Spanned<Self>>,
78+
arg: Box<Spanned<Self>>,
79+
},
80+
Func {
81+
arg: Box<Spanned<&'src str>>,
82+
body: Box<Spanned<Self>>,
83+
},
6984
}
7085

7186
type ParserInput<'src> = SpannedInput<Token<'src>, SimpleSpan, &'src [Spanned<Token<'src>>]>;
@@ -77,45 +92,173 @@ fn parser<'src>(
7792
let ident = select_ref! { Token::Ident(x) => *x };
7893
let atom = choice((
7994
select_ref! { Token::Num(x) => Expr::Num(*x) },
80-
ident.map(Expr::Local),
95+
just(Token::True).to(Expr::Bool(true)),
96+
just(Token::False).to(Expr::Bool(false)),
97+
ident.map(Expr::Var),
8198
// let x = y in z
8299
just(Token::Let)
83100
.ignore_then(ident.map_with(|x, e| (x, e.span())))
84101
.then_ignore(just(Token::Eq))
85102
.then(expr.clone())
86103
.then_ignore(just(Token::In))
87104
.then(expr.clone())
88-
.map(|((lhs, rhs), then)| Expr::Let(lhs, Box::new(rhs), Box::new(then))),
89-
// fn x y = z
90-
just(Token::Fn)
91-
.ignore_then(ident.map_with(|x, e| (x, e.span())).repeated().collect())
92-
.then_ignore(just(Token::Eq))
93-
.then(expr.clone())
94-
.map(|(args, body)| Expr::Func(args, Box::new(body))),
105+
.map(|((lhs, rhs), then)| Expr::Let {
106+
lhs,
107+
rhs: Box::new(rhs),
108+
then: Box::new(then),
109+
}),
95110
));
96111

97-
atom.map_with(|expr, e| (expr, e.span()))
112+
choice((
113+
atom.map_with(|expr, e| (expr, e.span())),
114+
// fn x y = z
115+
just(Token::Fn).ignore_then(
116+
ident.map_with(|x, e| (x, e.span())).repeated().foldr_with(
117+
just(Token::Eq).ignore_then(expr.clone()),
118+
|arg, body, e| {
119+
(
120+
Expr::Func {
121+
arg: Box::new(arg),
122+
body: Box::new(body),
123+
},
124+
e.span(),
125+
)
126+
},
127+
),
128+
),
98129
// ( x )
99-
.or(expr.nested_in(
130+
expr.nested_in(
100131
select_ref! { Token::Parens(ts) = e => ts.as_slice().spanned(e.span()) },
101-
))
102-
.pratt((
103-
// Multiply
104-
infix(left(10), just(Token::Asterisk), |x, _, y, e| {
105-
(Expr::Mul(Box::new(x), Box::new(y)), e.span())
106-
}),
107-
// Add
108-
infix(left(9), just(Token::Plus), |x, _, y, e| {
109-
(Expr::Add(Box::new(x), Box::new(y)), e.span())
110-
}),
111-
// Calls
112-
infix(left(1), empty(), |x, _, y, e| {
113-
(Expr::Call(Box::new(x), Box::new(y)), e.span())
114-
}),
115-
))
132+
),
133+
))
134+
.pratt((
135+
// Multiply
136+
infix(left(10), just(Token::Asterisk), |x, _, y, e| {
137+
(Expr::Mul(Box::new(x), Box::new(y)), e.span())
138+
}),
139+
// Add
140+
infix(left(9), just(Token::Plus), |x, _, y, e| {
141+
(Expr::Add(Box::new(x), Box::new(y)), e.span())
142+
}),
143+
// Calls
144+
infix(left(1), empty(), |x, _, y, e| {
145+
(
146+
Expr::Apply {
147+
func: Box::new(x),
148+
arg: Box::new(y),
149+
},
150+
e.span(),
151+
)
152+
}),
153+
))
116154
})
117155
}
118156

157+
#[derive(Copy, Clone, Debug, PartialEq)]
158+
struct TyVar(usize);
159+
160+
#[derive(Copy, Clone, Debug)]
161+
enum TyInfo {
162+
Unknown,
163+
Ref(TyVar),
164+
Num,
165+
Bool,
166+
Func(TyVar, TyVar),
167+
}
168+
169+
#[derive(Debug)]
170+
enum Ty {
171+
Num,
172+
Bool,
173+
Func(Box<Self>, Box<Self>),
174+
}
175+
176+
#[derive(Default)]
177+
struct Solver {
178+
vars: Vec<TyInfo>,
179+
}
180+
181+
impl Solver {
182+
fn create_ty(&mut self, info: TyInfo) -> TyVar {
183+
self.vars.push(info);
184+
TyVar(self.vars.len() - 1)
185+
}
186+
187+
fn unify(&mut self, a: TyVar, b: TyVar) {
188+
match (self.vars[a.0], self.vars[b.0]) {
189+
(TyInfo::Unknown, _) => self.vars[a.0] = TyInfo::Ref(b),
190+
(_, TyInfo::Unknown) => self.vars[b.0] = TyInfo::Ref(a),
191+
(TyInfo::Ref(a), _) => self.unify(a, b),
192+
(_, TyInfo::Ref(b)) => self.unify(a, b),
193+
(TyInfo::Num, TyInfo::Num) | (TyInfo::Bool, TyInfo::Bool) => {}
194+
(TyInfo::Func(a_i, a_o), TyInfo::Func(b_i, b_o)) => {
195+
self.unify(a_i, b_i);
196+
self.unify(a_o, b_o);
197+
}
198+
(a, b) => panic!("Type mismatch between {a:?} and {b:?}"),
199+
}
200+
}
201+
202+
fn check<'ast>(&mut self, expr: &Expr<'ast>, env: &mut Vec<(&'ast str, TyVar)>) -> TyVar {
203+
match expr {
204+
// Literal expressions are easy, their type doesn't need inferring.
205+
Expr::Num(_) => self.create_ty(TyInfo::Num),
206+
Expr::Bool(_) => self.create_ty(TyInfo::Bool),
207+
// We search the environment backward until we find a binding matching the variable name.
208+
Expr::Var(name) => {
209+
env.iter_mut()
210+
.rev()
211+
.find(|(n, _)| n == name)
212+
.expect("No such variable in scope")
213+
.1
214+
}
215+
// In a let expression, `rhs` gets bound with name `lhs` in the environment used to type-check `then`.
216+
Expr::Let { lhs, rhs, then } => {
217+
let rhs = self.check(&rhs.0, env);
218+
env.push((lhs.0, rhs));
219+
let out = self.check(&then.0, env);
220+
env.pop();
221+
out
222+
}
223+
// In a function, the argument becomes an unknown type in the environment used to type-check `body`.
224+
Expr::Func { arg, body } => {
225+
let arg_ty = self.create_ty(TyInfo::Unknown);
226+
env.push((arg.0, arg_ty));
227+
let body = self.check(&body.0, env);
228+
env.pop();
229+
self.create_ty(TyInfo::Func(arg_ty, body))
230+
}
231+
// During function application, both argument and function are type-checked and then we force the latter to be a function of the former.
232+
Expr::Apply { func, arg } => {
233+
let func = self.check(&func.0, env);
234+
let arg = self.check(&arg.0, env);
235+
let out = self.create_ty(TyInfo::Unknown);
236+
let func_ty = self.create_ty(TyInfo::Func(arg, out));
237+
self.unify(func_ty, func);
238+
out
239+
}
240+
Expr::Add(l, r) | Expr::Mul(l, r) => {
241+
let out = self.create_ty(TyInfo::Num);
242+
let l = self.check(&l.0, env);
243+
self.unify(out, l);
244+
let r = self.check(&r.0, env);
245+
self.unify(out, r);
246+
out
247+
}
248+
}
249+
}
250+
251+
pub fn solve(&self, var: TyVar) -> Ty {
252+
match self.vars[var.0] {
253+
TyInfo::Unknown => panic!("Cannot infer type"),
254+
TyInfo::Ref(var) => self.solve(var),
255+
TyInfo::Num => Ty::Num,
256+
TyInfo::Bool => Ty::Bool,
257+
TyInfo::Func(i, o) => Ty::Func(Box::new(self.solve(i)), Box::new(self.solve(o))),
258+
}
259+
}
260+
}
261+
119262
fn main() {
120263
let text = "
121264
let add = fn x y = x + y in
@@ -128,7 +271,18 @@ fn main() {
128271

129272
dbg!(&tokens);
130273

131-
let expr = parser().parse(tokens.spanned((0..text.len()).into()));
274+
let expr = parser()
275+
.parse(tokens.spanned((0..text.len()).into()))
276+
.unwrap();
277+
278+
dbg!(&expr);
279+
280+
let mut solver = Solver::default();
281+
282+
let program_ty = solver.check(&expr.0, &mut Vec::new());
132283

133-
dbg!(expr);
284+
println!(
285+
"The expression outputs type `{:?}`",
286+
solver.solve(program_ty)
287+
);
134288
}

0 commit comments

Comments
 (0)