@@ -15,6 +15,8 @@ pub enum Token<'src> {
1515 Let ,
1616 In ,
1717 Fn ,
18+ True ,
19+ False ,
1820}
1921
2022pub type Spanned < T > = ( T , SimpleSpan ) ;
@@ -26,6 +28,8 @@ fn lexer<'src>(
2628 "let" => Token :: Let ,
2729 "in" => Token :: In ,
2830 "fn" => Token :: Fn ,
31+ "true" => Token :: True ,
32+ "false" => Token :: False ,
2933 s => Token :: Ident ( s) ,
3034 } ) ;
3135
@@ -57,15 +61,26 @@ fn lexer<'src>(
5761 . collect ( )
5862}
5963
60- #[ derive( Debug ) ]
64+ #[ derive( Clone , Debug ) ]
6165pub enum Expr < ' src > {
62- Local ( & ' src str ) ,
66+ Var ( & ' src str ) ,
6367 Num ( f64 ) ,
64- Let ( Spanned < & ' src str > , Box < Spanned < Self > > , Box < Spanned < Self > > ) ,
68+ Bool ( bool ) ,
6569 Add ( Box < Spanned < Self > > , Box < Spanned < Self > > ) ,
6670 Mul ( Box < Spanned < Self > > , Box < Spanned < Self > > ) ,
67- Call ( Box < Spanned < Self > > , Box < Spanned < Self > > ) ,
68- Func ( Vec < Spanned < & ' src str > > , Box < Spanned < Self > > ) ,
71+ Let {
72+ lhs : Spanned < & ' src str > ,
73+ rhs : Box < Spanned < Self > > ,
74+ then : Box < Spanned < Self > > ,
75+ } ,
76+ Apply {
77+ func : Box < Spanned < Self > > ,
78+ arg : Box < Spanned < Self > > ,
79+ } ,
80+ Func {
81+ arg : Box < Spanned < & ' src str > > ,
82+ body : Box < Spanned < Self > > ,
83+ } ,
6984}
7085
7186type ParserInput < ' src > = SpannedInput < Token < ' src > , SimpleSpan , & ' src [ Spanned < Token < ' src > > ] > ;
@@ -77,45 +92,173 @@ fn parser<'src>(
7792 let ident = select_ref ! { Token :: Ident ( x) => * x } ;
7893 let atom = choice ( (
7994 select_ref ! { Token :: Num ( x) => Expr :: Num ( * x) } ,
80- ident. map ( Expr :: Local ) ,
95+ just ( Token :: True ) . to ( Expr :: Bool ( true ) ) ,
96+ just ( Token :: False ) . to ( Expr :: Bool ( false ) ) ,
97+ ident. map ( Expr :: Var ) ,
8198 // let x = y in z
8299 just ( Token :: Let )
83100 . ignore_then ( ident. map_with ( |x, e| ( x, e. span ( ) ) ) )
84101 . then_ignore ( just ( Token :: Eq ) )
85102 . then ( expr. clone ( ) )
86103 . then_ignore ( just ( Token :: In ) )
87104 . then ( expr. clone ( ) )
88- . map ( |( ( lhs, rhs) , then) | Expr :: Let ( lhs, Box :: new ( rhs) , Box :: new ( then) ) ) ,
89- // fn x y = z
90- just ( Token :: Fn )
91- . ignore_then ( ident. map_with ( |x, e| ( x, e. span ( ) ) ) . repeated ( ) . collect ( ) )
92- . then_ignore ( just ( Token :: Eq ) )
93- . then ( expr. clone ( ) )
94- . map ( |( args, body) | Expr :: Func ( args, Box :: new ( body) ) ) ,
105+ . map ( |( ( lhs, rhs) , then) | Expr :: Let {
106+ lhs,
107+ rhs : Box :: new ( rhs) ,
108+ then : Box :: new ( then) ,
109+ } ) ,
95110 ) ) ;
96111
97- atom. map_with ( |expr, e| ( expr, e. span ( ) ) )
112+ choice ( (
113+ atom. map_with ( |expr, e| ( expr, e. span ( ) ) ) ,
114+ // fn x y = z
115+ just ( Token :: Fn ) . ignore_then (
116+ ident. map_with ( |x, e| ( x, e. span ( ) ) ) . repeated ( ) . foldr_with (
117+ just ( Token :: Eq ) . ignore_then ( expr. clone ( ) ) ,
118+ |arg, body, e| {
119+ (
120+ Expr :: Func {
121+ arg : Box :: new ( arg) ,
122+ body : Box :: new ( body) ,
123+ } ,
124+ e. span ( ) ,
125+ )
126+ } ,
127+ ) ,
128+ ) ,
98129 // ( x )
99- . or ( expr. nested_in (
130+ expr. nested_in (
100131 select_ref ! { Token :: Parens ( ts) = e => ts. as_slice( ) . spanned( e. span( ) ) } ,
101- ) )
102- . pratt ( (
103- // Multiply
104- infix ( left ( 10 ) , just ( Token :: Asterisk ) , |x, _, y, e| {
105- ( Expr :: Mul ( Box :: new ( x) , Box :: new ( y) ) , e. span ( ) )
106- } ) ,
107- // Add
108- infix ( left ( 9 ) , just ( Token :: Plus ) , |x, _, y, e| {
109- ( Expr :: Add ( Box :: new ( x) , Box :: new ( y) ) , e. span ( ) )
110- } ) ,
111- // Calls
112- infix ( left ( 1 ) , empty ( ) , |x, _, y, e| {
113- ( Expr :: Call ( Box :: new ( x) , Box :: new ( y) ) , e. span ( ) )
114- } ) ,
115- ) )
132+ ) ,
133+ ) )
134+ . pratt ( (
135+ // Multiply
136+ infix ( left ( 10 ) , just ( Token :: Asterisk ) , |x, _, y, e| {
137+ ( Expr :: Mul ( Box :: new ( x) , Box :: new ( y) ) , e. span ( ) )
138+ } ) ,
139+ // Add
140+ infix ( left ( 9 ) , just ( Token :: Plus ) , |x, _, y, e| {
141+ ( Expr :: Add ( Box :: new ( x) , Box :: new ( y) ) , e. span ( ) )
142+ } ) ,
143+ // Calls
144+ infix ( left ( 1 ) , empty ( ) , |x, _, y, e| {
145+ (
146+ Expr :: Apply {
147+ func : Box :: new ( x) ,
148+ arg : Box :: new ( y) ,
149+ } ,
150+ e. span ( ) ,
151+ )
152+ } ) ,
153+ ) )
116154 } )
117155}
118156
157+ #[ derive( Copy , Clone , Debug , PartialEq ) ]
158+ struct TyVar ( usize ) ;
159+
160+ #[ derive( Copy , Clone , Debug ) ]
161+ enum TyInfo {
162+ Unknown ,
163+ Ref ( TyVar ) ,
164+ Num ,
165+ Bool ,
166+ Func ( TyVar , TyVar ) ,
167+ }
168+
169+ #[ derive( Debug ) ]
170+ enum Ty {
171+ Num ,
172+ Bool ,
173+ Func ( Box < Self > , Box < Self > ) ,
174+ }
175+
176+ #[ derive( Default ) ]
177+ struct Solver {
178+ vars : Vec < TyInfo > ,
179+ }
180+
181+ impl Solver {
182+ fn create_ty ( & mut self , info : TyInfo ) -> TyVar {
183+ self . vars . push ( info) ;
184+ TyVar ( self . vars . len ( ) - 1 )
185+ }
186+
187+ fn unify ( & mut self , a : TyVar , b : TyVar ) {
188+ match ( self . vars [ a. 0 ] , self . vars [ b. 0 ] ) {
189+ ( TyInfo :: Unknown , _) => self . vars [ a. 0 ] = TyInfo :: Ref ( b) ,
190+ ( _, TyInfo :: Unknown ) => self . vars [ b. 0 ] = TyInfo :: Ref ( a) ,
191+ ( TyInfo :: Ref ( a) , _) => self . unify ( a, b) ,
192+ ( _, TyInfo :: Ref ( b) ) => self . unify ( a, b) ,
193+ ( TyInfo :: Num , TyInfo :: Num ) | ( TyInfo :: Bool , TyInfo :: Bool ) => { }
194+ ( TyInfo :: Func ( a_i, a_o) , TyInfo :: Func ( b_i, b_o) ) => {
195+ self . unify ( a_i, b_i) ;
196+ self . unify ( a_o, b_o) ;
197+ }
198+ ( a, b) => panic ! ( "Type mismatch between {a:?} and {b:?}" ) ,
199+ }
200+ }
201+
202+ fn check < ' ast > ( & mut self , expr : & Expr < ' ast > , env : & mut Vec < ( & ' ast str , TyVar ) > ) -> TyVar {
203+ match expr {
204+ // Literal expressions are easy, their type doesn't need inferring.
205+ Expr :: Num ( _) => self . create_ty ( TyInfo :: Num ) ,
206+ Expr :: Bool ( _) => self . create_ty ( TyInfo :: Bool ) ,
207+ // We search the environment backward until we find a binding matching the variable name.
208+ Expr :: Var ( name) => {
209+ env. iter_mut ( )
210+ . rev ( )
211+ . find ( |( n, _) | n == name)
212+ . expect ( "No such variable in scope" )
213+ . 1
214+ }
215+ // In a let expression, `rhs` gets bound with name `lhs` in the environment used to type-check `then`.
216+ Expr :: Let { lhs, rhs, then } => {
217+ let rhs = self . check ( & rhs. 0 , env) ;
218+ env. push ( ( lhs. 0 , rhs) ) ;
219+ let out = self . check ( & then. 0 , env) ;
220+ env. pop ( ) ;
221+ out
222+ }
223+ // In a function, the argument becomes an unknown type in the environment used to type-check `body`.
224+ Expr :: Func { arg, body } => {
225+ let arg_ty = self . create_ty ( TyInfo :: Unknown ) ;
226+ env. push ( ( arg. 0 , arg_ty) ) ;
227+ let body = self . check ( & body. 0 , env) ;
228+ env. pop ( ) ;
229+ self . create_ty ( TyInfo :: Func ( arg_ty, body) )
230+ }
231+ // During function application, both argument and function are type-checked and then we force the latter to be a function of the former.
232+ Expr :: Apply { func, arg } => {
233+ let func = self . check ( & func. 0 , env) ;
234+ let arg = self . check ( & arg. 0 , env) ;
235+ let out = self . create_ty ( TyInfo :: Unknown ) ;
236+ let func_ty = self . create_ty ( TyInfo :: Func ( arg, out) ) ;
237+ self . unify ( func_ty, func) ;
238+ out
239+ }
240+ Expr :: Add ( l, r) | Expr :: Mul ( l, r) => {
241+ let out = self . create_ty ( TyInfo :: Num ) ;
242+ let l = self . check ( & l. 0 , env) ;
243+ self . unify ( out, l) ;
244+ let r = self . check ( & r. 0 , env) ;
245+ self . unify ( out, r) ;
246+ out
247+ }
248+ }
249+ }
250+
251+ pub fn solve ( & self , var : TyVar ) -> Ty {
252+ match self . vars [ var. 0 ] {
253+ TyInfo :: Unknown => panic ! ( "Cannot infer type" ) ,
254+ TyInfo :: Ref ( var) => self . solve ( var) ,
255+ TyInfo :: Num => Ty :: Num ,
256+ TyInfo :: Bool => Ty :: Bool ,
257+ TyInfo :: Func ( i, o) => Ty :: Func ( Box :: new ( self . solve ( i) ) , Box :: new ( self . solve ( o) ) ) ,
258+ }
259+ }
260+ }
261+
119262fn main ( ) {
120263 let text = "
121264 let add = fn x y = x + y in
@@ -128,7 +271,18 @@ fn main() {
128271
129272 dbg ! ( & tokens) ;
130273
131- let expr = parser ( ) . parse ( tokens. spanned ( ( 0 ..text. len ( ) ) . into ( ) ) ) ;
274+ let expr = parser ( )
275+ . parse ( tokens. spanned ( ( 0 ..text. len ( ) ) . into ( ) ) )
276+ . unwrap ( ) ;
277+
278+ dbg ! ( & expr) ;
279+
280+ let mut solver = Solver :: default ( ) ;
281+
282+ let program_ty = solver. check ( & expr. 0 , & mut Vec :: new ( ) ) ;
132283
133- dbg ! ( expr) ;
284+ println ! (
285+ "The expression outputs type `{:?}`" ,
286+ solver. solve( program_ty)
287+ ) ;
134288}
0 commit comments