Skip to content

Commit 8260f21

Browse files
committed
Update mccall_fitted_vfi.md
1 parent a8335e6 commit 8260f21

1 file changed

Lines changed: 119 additions & 119 deletions

File tree

lectures/mccall_fitted_vfi.md

Lines changed: 119 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ We will use the following imports:
5050

5151
```{code-cell} ipython3
5252
import matplotlib.pyplot as plt
53-
import numpy as np
54-
from numba import jit, float64
55-
from numba.experimental import jitclass
53+
import jax
54+
import jax.numpy as jnp
55+
from typing import NamedTuple
56+
import quantecon as qe
5657
```
5758

58-
## The Algorithm
59+
## The algorithm
5960

6061
The model is the same as the McCall model with job separation we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
6162

@@ -91,7 +92,7 @@ The function $q$ in {eq}`bell1mcmc` is the density of the wage offer distributio
9192

9293
Its support is taken as equal to $\mathbb R_+$.
9394

94-
### Value Function Iteration
95+
### Value function iteration
9596

9697
In theory, we should now proceed as follows:
9798

@@ -111,7 +112,7 @@ is to record its value $v'(w)$ for every $w \in \mathbb R_+$.
111112

112113
Clearly, this is impossible.
113114

114-
### Fitted Value Function Iteration
115+
### Fitted value function iteration
115116

116117
What we will do instead is use **fitted value function iteration**.
117118

@@ -145,21 +146,21 @@ This method
145146
{cite}`gordon1995stable` or {cite}`stachurski2008continuous`) and
146147
1. preserves useful shape properties such as monotonicity and concavity/convexity.
147148

148-
Linear interpolation will be implemented using [numpy.interp](https://numpy.org/doc/stable/reference/generated/numpy.interp.html).
149+
Linear interpolation will be implemented using JAX's interpolation function `jnp.interp`.
149150

150151
The next figure illustrates piecewise linear interpolation of an arbitrary
151152
function on grid points $0, 0.2, 0.4, 0.6, 0.8, 1$.
152153

153154
```{code-cell} python3
154155
def f(x):
155-
y1 = 2 * np.cos(6 * x) + np.sin(14 * x)
156+
y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)
156157
return y1 + 2.5
157158
158-
c_grid = np.linspace(0, 1, 6)
159-
f_grid = np.linspace(0, 1, 150)
159+
c_grid = jnp.linspace(0, 1, 6)
160+
f_grid = jnp.linspace(0, 1, 150)
160161
161162
def Af(x):
162-
return np.interp(x, c_grid, f(c_grid))
163+
return jnp.interp(x, c_grid, f(c_grid))
163164
164165
fig, ax = plt.subplots()
165166
@@ -175,123 +176,122 @@ plt.show()
175176

176177
## Implementation
177178

178-
The first step is to build a jitted class for the McCall model with separation and
179-
a continuous wage offer distribution.
179+
The first step is to build a JAX-compatible structure for the McCall model with separation and a continuous wage offer distribution.
180180

181181
We will take the utility function to be the log function for this application, with $u(c) = \ln c$.
182182

183183
We will adopt the lognormal distribution for wages, with $w = \exp(\mu + \sigma z)$
184184
when $z$ is standard normal and $\mu, \sigma$ are parameters.
185185

186186
```{code-cell} python3
187-
@jit
188187
def lognormal_draws(n=1000, μ=2.5, σ=0.5, seed=1234):
189-
np.random.seed(seed)
190-
z = np.random.randn(n)
191-
w_draws = np.exp(μ + σ * z)
188+
key = jax.random.PRNGKey(seed)
189+
z = jax.random.normal(key, (n,))
190+
w_draws = jnp.exp(μ + σ * z)
192191
return w_draws
193192
```
194193

195-
Here's our class.
194+
Here's our model structure using a NamedTuple.
196195

197196
```{code-cell} python3
198-
mccall_data_continuous = [
199-
('c', float64), # unemployment compensation
200-
('α', float64), # job separation rate
201-
('β', float64), # discount factor
202-
('w_grid', float64[:]), # grid of points for fitted VFI
203-
('w_draws', float64[:]) # draws of wages for Monte Carlo
204-
]
205-
206-
@jitclass(mccall_data_continuous)
207-
class McCallModelContinuous:
208-
209-
def __init__(self,
210-
c=1,
211-
α=0.1,
212-
β=0.96,
213-
grid_min=1e-10,
214-
grid_max=5,
215-
grid_size=100,
216-
w_draws=lognormal_draws()):
217-
218-
self.c, self.α, self.β = c, α, β
219-
220-
self.w_grid = np.linspace(grid_min, grid_max, grid_size)
221-
self.w_draws = w_draws
222-
223-
def update(self, v, d):
224-
225-
# Simplify names
226-
c, α, β = self.c, self.α, self.β
227-
w = self.w_grid
228-
u = lambda x: np.log(x)
229-
230-
# Interpolate array represented value function
231-
vf = lambda x: np.interp(x, w, v)
232-
233-
# Update d using Monte Carlo to evaluate integral
234-
d_new = np.mean(np.maximum(vf(self.w_draws), u(c) + β * d))
235-
236-
# Update v
237-
v_new = u(w) + β * ((1 - α) * v + α * d)
238-
239-
return v_new, d_new
197+
class McCallModelContinuous(NamedTuple):
198+
c: float # unemployment compensation
199+
α: float # job separation rate
200+
β: float # discount factor
201+
w_grid: jnp.ndarray # grid of points for fitted VFI
202+
w_draws: jnp.ndarray # draws of wages for Monte Carlo
203+
204+
def create_mccall_model(c=1,
205+
α=0.1,
206+
β=0.96,
207+
grid_min=1e-10,
208+
grid_max=5,
209+
grid_size=100,
210+
μ=2.5,
211+
σ=0.5,
212+
mc_size=1000,
213+
seed=1234):
214+
"""Factory function to create a McCall model instance."""
215+
w_draws = lognormal_draws(n=mc_size, μ=μ, σ=σ, seed=seed)
216+
w_grid = jnp.linspace(grid_min, grid_max, grid_size)
217+
return McCallModelContinuous(c=c, α=α, β=β, w_grid=w_grid, w_draws=w_draws)
218+
219+
@jax.jit
220+
def update(model, v, d):
221+
"""Update value function and continuation value."""
222+
# Unpack model parameters
223+
c, α, β, w_grid, w_draws = model
224+
u = jnp.log
225+
226+
# Interpolate array represented value function
227+
vf = lambda x: jnp.interp(x, w_grid, v)
228+
229+
# Update d using Monte Carlo to evaluate integral
230+
d_new = jnp.mean(jnp.maximum(vf(w_draws), u(c) + β * d))
231+
232+
# Update v
233+
v_new = u(w_grid) + β * ((1 - α) * v + α * d)
234+
235+
return v_new, d_new
240236
```
241237

242238
We then return the current iterate as an approximate solution.
243239

244240
```{code-cell} python3
245-
@jit
246-
def solve_model(mcm, tol=1e-5, max_iter=2000):
241+
@jax.jit
242+
def solve_model(model, tol=1e-5, max_iter=2000):
247243
"""
248244
Iterates to convergence on the Bellman equations
249245
250-
* mcm is an instance of McCallModel
246+
* model is an instance of McCallModelContinuous
251247
"""
252-
253-
v = np.ones_like(mcm.w_grid) # Initial guess of v
254-
d = 1 # Initial guess of d
255-
i = 0
256-
error = tol + 1
257-
258-
while error > tol and i < max_iter:
259-
v_new, d_new = mcm.update(v, d)
260-
error_1 = np.max(np.abs(v_new - v))
261-
error_2 = np.abs(d_new - d)
262-
error = max(error_1, error_2)
263-
v = v_new
264-
d = d_new
265-
i += 1
266-
267-
return v, d
248+
249+
# Initial guesses
250+
v = jnp.ones_like(model.w_grid)
251+
d = 1.0
252+
253+
def body_fun(state):
254+
v, d, i, error = state
255+
v_new, d_new = update(model, v, d)
256+
error_1 = jnp.max(jnp.abs(v_new - v))
257+
error_2 = jnp.abs(d_new - d)
258+
error = jnp.maximum(error_1, error_2)
259+
return v_new, d_new, i + 1, error
260+
261+
def cond_fun(state):
262+
v, d, i, error = state
263+
return (error > tol) & (i < max_iter)
264+
265+
initial_state = (v, d, 0, tol + 1)
266+
v_final, d_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
267+
268+
return v_final, d_final
268269
```
269270

270271
Here's a function `compute_reservation_wage` that takes an instance of `McCallModelContinuous`
271272
and returns the associated reservation wage.
272273

273-
If $v(w) < h$ for all $w$, then the function returns np.inf
274+
If $v(w) < h$ for all $w$, then the function returns `jnp.inf`
274275

275276
```{code-cell} python3
276-
@jit
277-
def compute_reservation_wage(mcm):
277+
@jax.jit
278+
def compute_reservation_wage(model):
278279
"""
279280
Computes the reservation wage of an instance of the McCall model
280281
by finding the smallest w such that v(w) >= h.
281282
282-
If no such w exists, then w_bar is set to np.inf.
283+
If no such w exists, then w_bar is set to inf.
283284
"""
284-
u = lambda x: np.log(x)
285-
286-
v, d = solve_model(mcm)
287-
h = u(mcm.c) + mcm.β * d
288-
289-
w_bar = np.inf
290-
for i, wage in enumerate(mcm.w_grid):
291-
if v[i] > h:
292-
w_bar = wage
293-
break
294-
285+
c, α, β, w_grid, w_draws = model
286+
u = jnp.log
287+
288+
v, d = solve_model(model)
289+
h = u(c) + β * d
290+
291+
# Find the first wage where v(w) >= h
292+
indices = jnp.where(v >= h, size=1, fill_value=-1)
293+
w_bar = jnp.where(indices[0] >= 0, w_grid[indices[0]], jnp.inf)
294+
295295
return w_bar
296296
```
297297

@@ -305,7 +305,7 @@ The exercises ask you to explore the solution and how it changes with parameters
305305
Use the code above to explore what happens to the reservation wage when the wage parameter $\mu$
306306
changes.
307307
308-
Use the default parameters and $\mu$ in `mu_vals = np.linspace(0.0, 2.0, 15)`.
308+
Use the default parameters and $\mu$ in `mu_vals = jnp.linspace(0.0, 2.0, 15)`.
309309
310310
Is the impact on the reservation wage as you expected?
311311
```
@@ -317,21 +317,18 @@ Is the impact on the reservation wage as you expected?
317317
Here is one solution
318318

319319
```{code-cell} python3
320-
mcm = McCallModelContinuous()
321-
mu_vals = np.linspace(0.0, 2.0, 15)
322-
w_bar_vals = np.empty_like(mu_vals)
323-
324-
fig, ax = plt.subplots()
320+
def compute_res_wage_given_mu(μ):
321+
model = create_mccall_model(μ=μ)
322+
w_bar = compute_reservation_wage(model)
323+
return w_bar
325324
326-
for i, m in enumerate(mu_vals):
327-
mcm.w_draws = lognormal_draws(μ=m)
328-
w_bar = compute_reservation_wage(mcm)
329-
w_bar_vals[i] = w_bar
325+
mu_vals = jnp.linspace(0.0, 2.0, 15)
326+
w_bar_vals = jax.vmap(compute_res_wage_given_mu)(mu_vals)
330327
328+
fig, ax = plt.subplots()
331329
ax.set(xlabel='mean', ylabel='reservation wage')
332330
ax.plot(mu_vals, w_bar_vals, label=r'$\bar w$ as a function of $\mu$')
333331
ax.legend()
334-
335332
plt.show()
336333
```
337334

@@ -354,11 +351,11 @@ support.
354351
355352
(This is a form of *mean-preserving spread*.)
356353
357-
Use `s_vals = np.linspace(1.0, 2.0, 15)` and `m = 2.0`.
354+
Use `s_vals = jnp.linspace(1.0, 2.0, 15)` and `m = 2.0`.
358355
359356
State how you expect the reservation wage to vary with $s$.
360357
361-
Now compute it. Is this as you expected?
358+
Now compute it - is this as you expected?
362359
```
363360

364361
```{solution-start} mfv_ex2
@@ -368,23 +365,26 @@ Now compute it. Is this as you expected?
368365
Here is one solution
369366

370367
```{code-cell} python3
371-
mcm = McCallModelContinuous()
372-
s_vals = np.linspace(1.0, 2.0, 15)
373-
m = 2.0
374-
w_bar_vals = np.empty_like(s_vals)
375-
376-
fig, ax = plt.subplots()
377-
378-
for i, s in enumerate(s_vals):
368+
def compute_res_wage_given_s(s, m=2.0, seed=1234):
379369
a, b = m - s, m + s
380-
mcm.w_draws = np.random.uniform(low=a, high=b, size=10_000)
381-
w_bar = compute_reservation_wage(mcm)
382-
w_bar_vals[i] = w_bar
370+
key = jax.random.PRNGKey(seed)
371+
uniform_draws = jax.random.uniform(key, shape=(10_000,), minval=a, maxval=b)
372+
# Create model with default parameters but replace wage draws
373+
model = create_mccall_model()
374+
model = model._replace(w_draws=uniform_draws)
375+
w_bar = compute_reservation_wage(model)
376+
return w_bar
383377
378+
s_vals = jnp.linspace(1.0, 2.0, 15)
379+
# Use vmap with different seeds for each s value
380+
seeds = jnp.arange(len(s_vals))
381+
compute_vectorized = jax.vmap(compute_res_wage_given_s, in_axes=(0, None, 0))
382+
w_bar_vals = compute_vectorized(s_vals, 2.0, seeds)
383+
384+
fig, ax = plt.subplots()
384385
ax.set(xlabel='volatility', ylabel='reservation wage')
385386
ax.plot(s_vals, w_bar_vals, label=r'$\bar w$ as a function of wage volatility')
386387
ax.legend()
387-
388388
plt.show()
389389
```
390390

0 commit comments

Comments
 (0)