@@ -28,7 +28,7 @@ universe u v w u₀ u₁ v₀ v₁
2828structure MonadCont.Label (α : Type w) (m : Type u → Type v) (β : Type u) where
2929 apply : α → m β
3030
31- def MonadCont.goto {α β} {m : Type u → Type v} (f : MonadCont.Label α m β) (x : α) :=
31+ abbrev MonadCont.goto {α β} {m : Type u → Type v} (f : MonadCont.Label α m β) (x : α) :=
3232 f.apply x
3333
3434class MonadCont (m : Type u → Type v) where
@@ -58,8 +58,11 @@ export MonadCont (Label goto)
5858
5959variable {r : Type u} {m : Type u → Type v} {α β : Type w}
6060
61- def run : ContT r m α → (α → m r) → m r :=
62- id
61+ /-- Build a `ContT` from a function taking a continuation callback. -/
62+ def mk (f : (α → m r) → m r) : ContT r m α := f
63+
64+ /-- Run a `ContT` with a provided callback. -/
65+ def run (x : ContT r m α) : (α → m r) → m r := x
6366
6467def map (f : m r → m r) (x : ContT r m α) : ContT r m α :=
6568 f ∘ x
@@ -81,24 +84,55 @@ instance : Monad (ContT r m) where
8184 pure x f := f x
8285 bind x f g := x fun i => f i g
8386
87+ @[simp]
88+ theorem run_mk (f : (α → m r) → m r) (k : α → m r) : (.mk f : ContT r m α).run k = f k := rfl
89+
90+ @[simp]
91+ theorem run_pure (a : α) (k : α → m r) : (pure a : ContT r m α).run k = k a := rfl
92+
93+ @[simp]
94+ theorem run_bind (x : ContT r m α) (f : α → ContT r m β) (k : β → m r) :
95+ (x >>= f).run k = x.run fun x => (f x).run k := rfl
96+
97+ @[simp]
98+ theorem run_map (f : α → β) (x : ContT r m α) (k : β → m r) :
99+ (f <$> x).run k = x.run (k ∘ f) := rfl
100+
101+ @[simp]
102+ theorem run_seq (f : ContT r m (α → β)) (x : ContT r m α) (k : β → m r) :
103+ (f <*> x).run k = f.run fun f => x.run (k ∘ f) := rfl
104+
105+ @[simp]
106+ theorem run_seqLeft (x : ContT r m α) (y : ContT r m β) (k : α → m r) :
107+ (x <* y).run k = x.run fun x => y.run fun _ => k x := rfl
108+
109+ @[simp]
110+ theorem run_seqRight (x : ContT r m α) (y : ContT r m β) (k : β → m r) :
111+ (x *> y).run k = x.run fun _ => y.run k := rfl
112+
84113instance : LawfulMonad (ContT r m) := LawfulMonad.mk'
85114 (id_map := by intros; rfl)
86115 (pure_bind := by intros; ext; rfl)
87116 (bind_assoc := by intros; ext; rfl)
88117
89- def monadLift [Monad m] {α} : m α → ContT r m α := fun x f => x >>= f
90-
91118instance [Monad m] : MonadLift m (ContT r m) where
92- monadLift := ContT.monadLift
119+ monadLift x := .mk fun k => x >>= k
120+
121+ @[simp]
122+ theorem run_monadLift [Monad m] {α} (x : m α) (k : α → m r) :
123+ (monadLift x : ContT r m α).run k = x >>= k := rfl
93124
94125theorem monadLift_bind [Monad m] [LawfulMonad m] {α β} (x : m α) (f : α → m β) :
95126 (monadLift (x >>= f) : ContT r m β) = monadLift x >>= monadLift ∘ f := by
96127 ext
97- simp only [monadLift, (· ∘ ·), (· >>= ·), bind_assoc, id, run,
98- ContT.monadLift]
128+ simp only [bind_assoc, run_bind, run_monadLift, Function.comp_apply]
99129
100130instance : MonadCont (ContT r m) where
101- callCC f g := f ⟨fun x _ => g x⟩ g
131+ callCC f := .mk fun k => f ⟨fun x => .mk fun _ => k x⟩ k
132+
133+ @[simp]
134+ theorem run_callCC (f : Label α (ContT r m) β → ContT r m α) (k : α → m r) :
135+ (callCC f).run k = (f ⟨fun x => .mk fun _ => k x⟩).run k := rfl
102136
103137instance : LawfulMonadCont (ContT r m) where
104138 callCC_bind_right := by intros; ext; rfl
@@ -121,8 +155,8 @@ See [Zulip](https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topi
121155for further discussion.
122156-/
123157instance (ε) [MonadExceptOf ε m] : MonadExceptOf ε (ContT r m) where
124- throw e _ := throw e
125- tryCatch act h f := tryCatch (act.run f ) fun e => (h e).run f
158+ throw e := .mk fun _ => throw e
159+ tryCatch act h := .mk fun k => tryCatch (act.run k ) fun e => (h e).run k
126160
127161@[simp]
128162theorem run_throw {ε} [MonadExceptOf ε m]
0 commit comments