You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -50,12 +50,13 @@ We will use the following imports:
50
50
51
51
```{code-cell} ipython3
52
52
import matplotlib.pyplot as plt
53
-
import numpy as np
54
-
from numba import jit, float64
55
-
from numba.experimental import jitclass
53
+
import jax
54
+
import jax.numpy as jnp
55
+
from typing import NamedTuple
56
+
import quantecon as qe
56
57
```
57
58
58
-
## The Algorithm
59
+
## The algorithm
59
60
60
61
The model is the same as the McCall model with job separation we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
61
62
@@ -91,7 +92,7 @@ The function $q$ in {eq}`bell1mcmc` is the density of the wage offer distributio
91
92
92
93
Its support is taken as equal to $\mathbb R_+$.
93
94
94
-
### Value Function Iteration
95
+
### Value function iteration
95
96
96
97
In theory, we should now proceed as follows:
97
98
@@ -111,7 +112,7 @@ is to record its value $v'(w)$ for every $w \in \mathbb R_+$.
111
112
112
113
Clearly, this is impossible.
113
114
114
-
### Fitted Value Function Iteration
115
+
### Fitted value function iteration
115
116
116
117
What we will do instead is use **fitted value function iteration**.
117
118
@@ -145,21 +146,21 @@ This method
145
146
{cite}`gordon1995stable` or {cite}`stachurski2008continuous`) and
146
147
1. preserves useful shape properties such as monotonicity and concavity/convexity.
147
148
148
-
Linear interpolation will be implemented using [numpy.interp](https://numpy.org/doc/stable/reference/generated/numpy.interp.html).
149
+
Linear interpolation will be implemented using JAX's interpolation function `jnp.interp`.
149
150
150
151
The next figure illustrates piecewise linear interpolation of an arbitrary
151
152
function on grid points $0, 0.2, 0.4, 0.6, 0.8, 1$.
152
153
153
154
```{code-cell} python3
154
155
def f(x):
155
-
y1 = 2 * np.cos(6 * x) + np.sin(14 * x)
156
+
y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)
156
157
return y1 + 2.5
157
158
158
-
c_grid = np.linspace(0, 1, 6)
159
-
f_grid = np.linspace(0, 1, 150)
159
+
c_grid = jnp.linspace(0, 1, 6)
160
+
f_grid = jnp.linspace(0, 1, 150)
160
161
161
162
def Af(x):
162
-
return np.interp(x, c_grid, f(c_grid))
163
+
return jnp.interp(x, c_grid, f(c_grid))
163
164
164
165
fig, ax = plt.subplots()
165
166
@@ -175,123 +176,122 @@ plt.show()
175
176
176
177
## Implementation
177
178
178
-
The first step is to build a jitted class for the McCall model with separation and
179
-
a continuous wage offer distribution.
179
+
The first step is to build a JAX-compatible structure for the McCall model with separation and a continuous wage offer distribution.
180
180
181
181
We will take the utility function to be the log function for this application, with $u(c) = \ln c$.
182
182
183
183
We will adopt the lognormal distribution for wages, with $w = \exp(\mu + \sigma z)$
184
184
when $z$ is standard normal and $\mu, \sigma$ are parameters.
0 commit comments