@@ -21,11 +21,11 @@ translation:
2121 JAX as a NumPy Replacement::Differences::Size Experiment : 大小实验
2222 JAX as a NumPy Replacement::Differences::Precision : 精度
2323 JAX as a NumPy Replacement::Differences::Immutability : 不可变性
24- JAX as a NumPy Replacement::Differences::A Workaround : 变通方法
24+ JAX as a NumPy Replacement::Differences::A workaround : 变通方法
2525 Functional Programming : 函数式编程
2626 Functional Programming::Pure functions : 纯函数
2727 Functional Programming::Examples : 示例
28- Functional Programming::Why Functional Programming? : 为什么要函数式编程 ?
28+ Functional Programming::Why Functional Programming? : 为什么使用函数式编程 ?
2929 Random numbers : 随机数
3030 Random numbers::Random number generation : 随机数生成
3131 Random numbers::Why explicit random state? : 为什么要显式随机状态?
@@ -37,10 +37,6 @@ translation:
3737 JIT Compilation::Compiling the Whole Function : 编译整个函数
3838 JIT Compilation::How JIT compilation works : JIT 编译的工作原理
3939 JIT Compilation::Compiling non-pure functions : 编译非纯函数
40- Vectorization with vmap : 使用 vmap 进行向量化
41- Vectorization with vmap::A simple example : 一个简单的示例
42- Vectorization with vmap::Combining transformations : 组合变换
43- Automatic differentiation : a preview: 自动微分:预览
4440 Exercises : 练习
4541---
4642
@@ -77,17 +73,17 @@ import numpy as np
7773import quantecon as qe
7874```
7975
76+ 注意我们导入了 ` jax.numpy as jnp ` ,它提供了类似 NumPy 的接口。
77+
8078## JAX 作为 NumPy 的替代品
8179
82- 让我们来看看 JAX 和 NumPy 之间的异同 。
80+ JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API 。
8381
84- ### 相似之处
82+ 这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。
8583
86- 上面我们导入了 ` jax.numpy as jnp ` ,它提供了类似 NumPy 的数组操作接口。
87-
88- JAX 的一个吸引人之处在于,这个接口在尽可能的情况下遵循 NumPy API。
84+ 让我们来看看 JAX 和 NumPy 之间的异同。
8985
90- 因此,我们通常可以将 JAX 作为 NumPy 的直接替代品使用。
86+ ### 相似之处
9187
9288以下是使用 ` jnp ` 进行的一些标准数组操作:
9389
@@ -107,7 +103,7 @@ print(jnp.sum(a))
107103print(jnp.dot(a, a))
108104```
109105
110- 但需要注意的是 ,数组对象 ` a ` 并不是 NumPy 数组:
106+ 然而 ,数组对象 ` a ` 并不是 NumPy 数组:
111107
112108``` {code-cell} ipython3
113109a
117113type(a)
118114```
119115
120- 即使是数组上的标量值映射也会返回 JAX 数组而非标量 !
116+ 即使是数组上的标量值映射也会返回 JAX 数组,而不是标量 !
121117
122118``` {code-cell} ipython3
123119jnp.sum(a)
@@ -130,18 +126,16 @@ jnp.sum(a)
130126(jax_speed)=
131127#### 速度!
132128
133- 一个主要差异是 JAX 更快——有时快得多。
134-
135- 为了说明这一点,假设我们想在许多点处计算余弦函数。
129+ 假设我们想在许多点上计算余弦函数。
136130
137131``` {code-cell}
138132n = 50_000_000
139- x = np.linspace(0, 10, n) # NumPy array
133+ x = np.linspace(0, 10, n)
140134```
141135
142136##### 使用 NumPy
143137
144- 让我们用 NumPy 来试试
138+ 让我们先用 NumPy 试试:
145139
146140``` {code-cell}
147141with qe.Timer():
@@ -159,7 +153,7 @@ with qe.Timer():
159153
160154这里
161155
162- * NumPy 使用预编译的二进制文件将余弦函数应用于浮点数数组
156+ * NumPy 使用预编译的二进制文件对浮点数数组应用余弦函数
163157* 该二进制文件在本地机器的 CPU 上运行
164158
165159##### 使用 JAX
@@ -181,8 +175,11 @@ with qe.Timer():
181175```
182176
183177``` {note}
184- 上面的 `block_until_ready` 方法会阻塞解释器,直到计算结果返回。
185- 这对于计时是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前继续运行。
178+ 这里,为了测量实际速度,我们使用 `block_until_ready` 方法来阻塞解释器,直到计算结果返回。
179+
180+ 这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。
181+
182+ 对于非计时代码,可以删除包含 `block_until_ready` 的那一行。
186183```
187184
188185再来计时一次。
@@ -277,8 +274,7 @@ a[0] = 1
277274a
278275```
279276
280- 在 JAX 中,这会失败 😱。
281-
277+ 在 JAX 中,这会失败!
282278
283279``` {code-cell} ipython3
284280a = jnp.linspace(0, 1, 3)
@@ -290,21 +286,13 @@ try:
290286 a[0] = 1
291287except Exception as e:
292288 print(e)
293-
294289```
295290
296- JAX 的设计者选择将数组设为不可变的,因为
297-
298- 1 . JAX 使用* 函数式编程风格* ,并且
299- 2 . 函数式编程通常避免可变数据
300-
301- 我们将在 {ref}` 下面 <jax_func> ` 讨论这些思想。
291+ JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。
302292
303-
304- (jax_at_workaround)=
305293#### 变通方法
306294
307- JAX 确实通过 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 提供了原地数组修改的直接替代方案 。
295+ 我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 [ ` at ` 方法] ( https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html ) 。
308296
309297``` {code-cell} ipython3
310298a = jnp.linspace(0, 1, 3)
326314
327315(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)
328316
329-
330- (jax_func)=
331317## 函数式编程
332318
333319来自 JAX 的文档:
@@ -860,40 +846,6 @@ fast_batch_mm_diff(X)
860846
861847` jit ` 、` vmap ` 以及(我们接下来将看到的)` grad ` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
862848
863-
864- ## 自动微分:预览
865-
866- JAX 可以使用自动微分来计算梯度。
867-
868- 这对于优化和求解非线性系统非常有用。
869-
870- 以下是一个简单的示例,涉及函数 $f(x) = x^2 / 2$:
871-
872- ``` {code-cell} ipython3
873- def f(x):
874- return (x**2) / 2
875-
876- f_prime = jax.grad(f)
877- ```
878-
879- ``` {code-cell} ipython3
880- f_prime(10.0)
881- ```
882-
883- 让我们绘制函数和导数,注意 $f'(x) = x$。
884-
885- ``` {code-cell} ipython3
886- fig, ax = plt.subplots()
887- x_grid = jnp.linspace(-4, 4, 200)
888- ax.plot(x_grid, f(x_grid), label="$f$")
889- ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
890- ax.legend(loc='upper center')
891- plt.show()
892- ```
893-
894- 自动微分是一个有许多经济学和金融学应用的深层主题。我们在{doc}` 自动微分讲座 <autodiff> ` 中提供了更深入的讨论。
895-
896-
897849## 练习
898850
899851
0 commit comments