Skip to content

Commit 2cfda1a

Browse files
committed
jit computation functions
1 parent 211ec24 commit 2cfda1a

1 file changed

Lines changed: 84 additions & 65 deletions

File tree

lectures/markov_asset.md

Lines changed: 84 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ An asset is a claim on one or more future payoffs.
6363

6464
The spot price of an asset depends primarily on
6565

66-
* the anticipated income stream
66+
* the anticipated income stream
6767
* attitudes about risk
6868
* rates of time preference
6969

@@ -75,15 +75,14 @@ We also look at creating and pricing *derivative* assets that repackage income s
7575

7676
Key tools for the lecture are
7777

78-
* Markov processses
78+
* Markov processes
7979
* formulas for predicting future values of functions of a Markov state
8080
* a formula for predicting the discounted sum of future values of a Markov state
8181

8282
Let's start with some imports:
8383

8484
```{code-cell} ipython
8585
import matplotlib.pyplot as plt
86-
import numpy as np
8786
import quantecon as qe
8887
import jax
8988
import jax.numpy as jnp
@@ -151,7 +150,7 @@ for some **stochastic discount factor** $m_{t+1}$.
151150

152151
Here the fixed discount factor $\beta$ in {eq}`rnapex` has been replaced by the random variable $m_{t+1}$.
153152

154-
How anticipated future payoffs are evaluated now depends on statistical properties of $m_{t+1}$.
153+
How anticipated future payoffs are evaluated now depends on statistical properties of $m_{t+1}$.
155154

156155
The stochastic discount factor can be specified to capture the idea that assets that tend to have good payoffs in bad states of the world are valued more highly than other assets whose payoffs don't behave that way.
157156

@@ -177,12 +176,12 @@ If we apply this definition to the asset pricing equation {eq}`lteeqs0` we obtai
177176
p_t = {\mathbb E}_t m_{t+1} {\mathbb E}_t (d_{t+1} + p_{t+1}) + {\rm cov}_t (m_{t+1}, d_{t+1}+ p_{t+1})
178177
```
179178

180-
It is useful to regard equation {eq}`lteeqs102` as a generalization of equation {eq}`rnapex`
179+
It is useful to regard equation {eq}`lteeqs102` as a generalization of equation {eq}`rnapex`
181180

182-
* In equation {eq}`rnapex`, the stochastic discount factor $m_{t+1} = \beta$, a constant.
181+
* In equation {eq}`rnapex`, the stochastic discount factor $m_{t+1} = \beta$, a constant.
183182
* In equation {eq}`rnapex`, the covariance term ${\rm cov}_t (m_{t+1}, d_{t+1}+ p_{t+1})$ is zero because $m_{t+1} = \beta$.
184183
* In equation {eq}`rnapex`, ${\mathbb E}_t m_{t+1}$ can be interpreted as the reciprocal of the one-period risk-free gross interest rate.
185-
* When $m_{t+1}$ covaries more negatively with the payout $p_{t+1} + d_{t+1}$, the price of the asset is lower.
184+
* When $m_{t+1}$ covaries more negatively with the payout $p_{t+1} + d_{t+1}$, the price of the asset is lower.
186185

187186
Equation {eq}`lteeqs102` asserts that the covariance of the stochastic discount factor with the one period payout $d_{t+1} + p_{t+1}$ is an important determinant of the price $p_t$.
188187

@@ -213,9 +212,9 @@ The answer to this question depends on
213212
1. the process we specify for dividends
214213
1. the stochastic discount factor and how it correlates with dividends
215214

216-
For now we'll study the risk-neutral case in which the stochastic discount factor is constant.
215+
For now we'll study the risk-neutral case in which the stochastic discount factor is constant.
217216

218-
We'll focus on how an asset price depends on a dividend process.
217+
We'll focus on how an asset price depends on a dividend process.
219218

220219
### Example 1: constant dividends
221220

@@ -340,7 +339,7 @@ x_series = mc.simulate(sim_length, init=jnp.median(mc.state_values))
340339
g_series = jnp.exp(x_series)
341340
d_series = jnp.cumprod(g_series) # Assumes d_0 = 1
342341
343-
series = [x_series, g_series, d_series, np.log(d_series)]
342+
series = [x_series, g_series, d_series, jnp.log(d_series)]
344343
labels = ['$X_t$', '$g_t$', '$d_t$', r'$\log \, d_t$']
345344
346345
fig, axes = plt.subplots(2, 2)
@@ -564,8 +563,8 @@ Assuming that the spectral radius of $J$ is strictly less than $\beta^{-1}$, thi
564563
v = (I - \beta J)^{-1} \beta J {\mathbb 1}
565564
```
566565
567-
We will define a function tree_price to compute $v$ given parameters stored in
568-
the class AssetPriceModel
566+
We will define a function `tree_price` to compute $v$ given parameters stored in
567+
the class `AssetPriceModel`
569568
570569
```{code-cell} ipython3
571570
class MarkovChain(NamedTuple):
@@ -578,8 +577,8 @@ class MarkovChain(NamedTuple):
578577
state_values : jnp.ndarray
579578
The values associated with each state
580579
"""
581-
P: jnp.ndarray
582-
state_values: jnp.ndarray
580+
P: jax.Array
581+
state_values: jax.Array
583582
584583
585584
class AssetPriceModel(NamedTuple):
@@ -590,20 +589,17 @@ class AssetPriceModel(NamedTuple):
590589
----------
591590
mc : MarkovChain
592591
Contains the transition matrix and set of state values
593-
g : callable
594-
The function mapping states to growth rates
592+
G : jax.Array
593+
The vector form of the function mapping states to growth rates
595594
β : float
596595
Discount factor
597596
γ : float
598597
Coefficient of risk aversion
599-
n: int
600-
The number of states
601598
"""
602599
mc: MarkovChain
603-
g: callable
600+
G: jax.Array
604601
β: float
605602
γ: float
606-
n: int
607603
608604
609605
def create_ap_model(g=jnp.exp, β=0.96, γ=2.0):
@@ -612,15 +608,16 @@ def create_ap_model(g=jnp.exp, β=0.96, γ=2.0):
612608
qe_mc = qe.tauchen(n, ρ, σ)
613609
P = jnp.array(qe_mc.P)
614610
state_values = jnp.array(qe_mc.state_values)
611+
G = g(state_values)
615612
mc = MarkovChain(P=P, state_values=state_values)
616613
617-
return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
614+
return AssetPriceModel(mc=mc, G=G, β=β, γ=γ)
618615
619616
620617
def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0):
621618
"""Create an AssetPriceModel class using a customized Markov chain."""
622-
n = mc.P.shape[0]
623-
return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n)
619+
G = g(mc.state_values)
620+
return AssetPriceModel(mc=mc, G=G, β=β, γ=γ)
624621
625622
626623
def test_stability(Q, β):
@@ -633,9 +630,6 @@ def test_stability(Q, β):
633630
return sr
634631
635632
636-
# Wrap the check function to be JIT-safe
637-
test_stability = checkify.checkify(test_stability, errors=checkify.user_checks)
638-
639633
def tree_price(ap):
640634
"""
641635
Computes the price-dividend ratio of the Lucas tree.
@@ -649,22 +643,24 @@ def tree_price(ap):
649643
-------
650644
v : array_like(float)
651645
Lucas tree price-dividend ratio
652-
653646
"""
654647
# Simplify names, set up matrices
655-
β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values
656-
J = P * ap.g(y)**(1 - γ)
648+
β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G
649+
J = P * G**(1 - γ)
657650
658651
# Make sure that a unique solution exists
659-
err, out = test_stability(J, β)
660-
err.throw()
652+
test_stability(J, β)
661653
662654
# Compute v
663-
I = jnp.identity(ap.n)
664-
Ones = jnp.ones(ap.n)
655+
n = J.shape[0]
656+
I = jnp.identity(n)
657+
Ones = jnp.ones(n)
665658
v = solve(I - β * J, β * J @ Ones)
666659
667660
return v
661+
662+
# Wrap the function to be safely jitted
663+
tree_price_jit = jax.jit(checkify.checkify(tree_price))
668664
```
669665
670666
Here's a plot of $v$ as a function of the state for several values of $\gamma$,
@@ -685,8 +681,12 @@ states = ap.mc.state_values
685681
fig, ax = plt.subplots()
686682
687683
for γ in γs:
688-
tem_ap = create_customized_ap_model(mc=ap.mc, β=ap.β, γ=γ)
689-
v = tree_price(tem_ap)
684+
tem_ap = create_customized_ap_model(ap.mc, γ=γ)
685+
# checkify returns a tuple
686+
# err indicates whether errors happened
687+
err, v = tree_price_jit(tem_ap)
688+
# Stop if errors raised
689+
err.throw()
690690
ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$")
691691
692692
ax.set_ylabel("price-dividend ratio")
@@ -766,7 +766,7 @@ yields the solution
766766
p = (I - \beta M)^{-1} \beta M \zeta {\mathbb 1}
767767
```
768768
769-
The above is implemented in the function consol_price.
769+
The above is implemented in the function `consol_price`.
770770
771771
```{code-cell} ipython3
772772
def consol_price(ap, ζ):
@@ -787,19 +787,22 @@ def consol_price(ap, ζ):
787787
Console bond prices
788788
"""
789789
# Simplify names, set up matrices
790-
β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values
791-
M = P * ap.g(y)**(- γ)
790+
β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G
791+
M = P * G**(- γ)
792792
793793
# Make sure that a unique solution exists
794-
err, _ = test_stability(M, β)
795-
err.throw()
794+
test_stability(M, β)
796795
797796
# Compute price
798-
I = jnp.identity(ap.n)
799-
Ones = jnp.ones(ap.n)
797+
n = M.shape[0]
798+
I = jnp.identity(n)
799+
Ones = jnp.ones(n)
800800
p = solve(I - β * M, β * ζ * M @ Ones)
801801
802802
return p
803+
804+
# Wrap the function to be safely jitted
805+
consol_price_jit = jax.jit(checkify.checkify(consol_price))
803806
```
804807
805808
### Pricing an Option to Purchase the Consol
@@ -870,9 +873,9 @@ T w
870873
= \max \{ \beta M w,\; p - p_S {\mathbb 1} \}
871874
$$
872875
873-
Start at some initial $w$ and iterate with $T$ to convergence .
876+
Start at some initial $w$ and iterate with $T$ to convergence.
874877
875-
We can find the solution with the following function call_option
878+
We can find the solution with the following function `call_option`
876879
877880
```{code-cell} ipython3
878881
def call_option(ap, ζ, p_s, ϵ=1e-7):
@@ -900,16 +903,17 @@ def call_option(ap, ζ, p_s, ϵ=1e-7):
900903
901904
"""
902905
# Simplify names, set up matrices
903-
β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values
904-
M = P * ap.g(y)**(- γ)
906+
β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G
907+
M = P * G**(- γ)
905908
906909
# Make sure that a unique consol price exists
907-
err, _ = test_stability(M, β)
908-
err.throw()
910+
test_stability(M, β)
909911
910912
# Compute option price
911913
p = consol_price(ap, ζ)
912-
w = jnp.zeros(ap.n)
914+
err.throw()
915+
n = M.shape[0]
916+
w = jnp.zeros(n)
913917
error = ϵ + 1
914918
915919
def step(state):
@@ -928,6 +932,8 @@ def call_option(ap, ζ, p_s, ϵ=1e-7):
928932
final_w, _ = jax.lax.while_loop(cond, step, (w, error))
929933
930934
return final_w
935+
936+
call_option_jit = jax.jit(checkify.checkify(call_option))
931937
```
932938
933939
Here's a plot of $w$ compared to the consol price when $P_S = 40$
@@ -945,8 +951,10 @@ ap = create_ap_model(β=0.9)
945951
strike_price = 40
946952
947953
x = ap.mc.state_values
948-
p = consol_price(ap, ζ)
949-
w = call_option(ap, ζ, strike_price)
954+
err, p = consol_price_jit(ap, ζ)
955+
err.throw()
956+
err, w = call_option_jit(ap, ζ, strike_price)
957+
err.throw()
950958
951959
fig, ax = plt.subplots()
952960
ax.plot(x, p, 'b-', lw=2, label='consol price')
@@ -1046,7 +1054,7 @@ P = P.at[jnp.arange(n), jnp.arange(n)].set(
10461054
P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1)
10471055
)
10481056
# State values of the Markov chain
1049-
s = np.array([0.95, 0.975, 1.0, 1.025, 1.05])
1057+
s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05])
10501058
γ = 2.0
10511059
β = 0.94
10521060
```
@@ -1076,7 +1084,7 @@ P = P.at[jnp.arange(n), jnp.arange(n)].set(
10761084
P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1)
10771085
)
10781086
s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values
1079-
mc = qe.MarkovChain(P, state_values=s)
1087+
mc = MarkovChain(P=P, state_values=s)
10801088
10811089
γ = 2.0
10821090
β = 0.94
@@ -1094,23 +1102,29 @@ apm = create_customized_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ)
10941102
Now we just need to call the relevant functions on the data:
10951103
10961104
```{code-cell} ipython3
1097-
tree_price(apm)
1105+
err, v = tree_price_jit(apm)
1106+
err.throw()
1107+
print(v)
10981108
```
10991109
11001110
```{code-cell} ipython3
1101-
consol_price(apm, ζ)
1111+
err, p = consol_price_jit(apm, ζ)
1112+
err.throw()
1113+
print(p)
11021114
```
11031115
11041116
```{code-cell} ipython3
1105-
call_option(apm, ζ, p_s)
1117+
err, w = call_option_jit(apm, ζ, p_s)
1118+
err.throw()
1119+
print(w)
11061120
```
11071121
11081122
Let's show the last two functions as a plot
11091123
11101124
```{code-cell} ipython3
11111125
fig, ax = plt.subplots()
1112-
ax.plot(s, consol_price(apm, ζ), label='consol')
1113-
ax.plot(s, call_option(apm, ζ, p_s), label='call option')
1126+
ax.plot(s, p, label='consol')
1127+
ax.plot(s, w, label='call option')
11141128
ax.legend()
11151129
plt.show()
11161130
```
@@ -1169,36 +1183,41 @@ Is one higher than the other? Can you give intuition?
11691183
Here's a suitable function:
11701184
11711185
```{code-cell} ipython3
1172-
def finite_horizon_call_option(ap, ζ, p_s, k):
1186+
def finite_call_option(ap, ζ, p_s, k):
11731187
"""
11741188
Computes k period option value.
11751189
"""
11761190
# Simplify names, set up matrices
1177-
β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values
1178-
M = P * ap.g(y)**(- γ)
1191+
β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G
1192+
M = P * G**(- γ)
11791193
11801194
# Make sure that a unique solution exists
1181-
err, _ = test_stability(M, β)
1182-
err.throw()
1195+
test_stability(M, β)
11831196
11841197
# Compute option price
11851198
p = consol_price(ap, ζ)
1199+
n = M.shape[0]
11861200
def step(i, w):
11871201
# Maximize across columns
11881202
w = jnp.maximum(β * M @ w, p - p_s)
11891203
return w
11901204
1191-
w = jax.lax.fori_loop(0, k, step, jnp.zeros(ap.n))
1205+
w = jax.lax.fori_loop(0, k, step, jnp.zeros(n))
11921206
11931207
return w
1208+
1209+
finite_call_option_jit = jax.jit(
1210+
checkify.checkify(finite_call_option)
1211+
)
11941212
```
11951213
11961214
Now let's compute the option values at `k=5` and `k=25`
11971215
11981216
```{code-cell} ipython3
11991217
fig, ax = plt.subplots()
12001218
for k in [5, 25]:
1201-
w = finite_horizon_call_option(apm, ζ, p_s, k)
1219+
err, w = finite_call_option_jit(apm, ζ, p_s, k)
1220+
err.throw()
12021221
ax.plot(s, w, label=rf'$k = {k}$')
12031222
ax.legend()
12041223
plt.show()

0 commit comments

Comments
 (0)