@@ -13,6 +13,7 @@ translation:
1313 Vectorized operations : 向量化运算
1414 Vectorized operations::Problem Statement : 问题陈述
1515 Vectorized operations::NumPy vectorization : NumPy 向量化
16+ Vectorized operations::Memory Issues : 内存问题
1617 Vectorized operations::A Comparison with Numba : 与 Numba 的比较
1718 Vectorized operations::Parallelized Numba : 并行化的 Numba
1819 Vectorized operations::Vectorized code with JAX : 使用 JAX 的向量化代码
@@ -152,16 +153,33 @@ for x in grid:
152153
153154让我们切换到 NumPy 并使用更大的网格。
154155
156+ ``` {code-cell} ipython3
157+ grid = np.linspace(-3, 3, 3_000) # Large grid
158+ ```
159+
160+ 作为向量化的第一步,我们可能会尝试这样的方式
161+
162+ ``` {code-cell} ipython3
163+ # Large grid
164+ z = np.max(f(grid, grid)) # This is wrong!
165+ ```
166+
167+ 这里的问题是 ` f(grid, grid) ` 并不遵循嵌套循环。
168+
169+ 从上图来看,它只计算了对角线上的 ` f ` 值。
170+
171+ 要让 NumPy 在每个 ` x,y ` 对上计算 ` f(x,y) ` ,我们需要使用 ` np.meshgrid ` 。
172+
155173这里我们使用 ` np.meshgrid ` 来创建二维输入网格 ` x ` 和 ` y ` ,使得 ` f(x, y) ` 能生成乘积网格上的所有计算结果。
156174
157175``` {code-cell} ipython3
158176# Large grid
159177grid = np.linspace(-3, 3, 3_000)
160178
161- x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
179+ x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid
162180
163181with qe.Timer():
164- z_max_numpy = np.max(f(x, y))
182+ z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
165183```
166184
167185在向量化版本中,所有循环都在编译后的代码中执行。
@@ -174,9 +192,29 @@ with qe.Timer():
174192print(f"NumPy result: {z_max_numpy:.6f}")
175193```
176194
195+ ### 内存问题
196+
197+ 我们在合理的时间内得到了正确的解——但内存使用量非常大。
198+
199+ 虽然扁平数组占用内存较少
200+
201+ ``` {code-cell} ipython3
202+ grid.nbytes
203+ ```
204+
205+ 但网格矩阵是二维的,因此内存占用非常大
206+
207+ ``` {code-cell} ipython3
208+ x_mesh.nbytes + y_mesh.nbytes
209+ ```
210+
211+ 此外,NumPy 的即时执行会创建许多相同大小的中间数组!
212+
213+ 在实际研究计算中,这种内存使用可能是一个大问题。
214+
177215### 与 Numba 的比较
178216
179- 现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
217+ 让我们看看能否使用简单循环的 Numba 获得更好的性能。
180218
181219``` {code-cell} ipython3
182220@numba.jit
@@ -207,13 +245,13 @@ with qe.Timer():
207245 compute_max_numba(grid)
208246```
209247
210- 根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快 。
248+ 注意我们几乎不使用任何内存——我们只需要一维的 ` grid ` 。
211249
212- 在大多数情况下,我们发现 Numba 略胜一筹 。
250+ 此外,执行速度也很好 。
213251
214- 一方面, NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势 。
252+ 在大多数机器上,Numba 版本会比 NumPy 稍快一些 。
215253
216- 另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格 。
254+ 原因是高效的机器码加上更少的内存读写 。
217255
218256### 并行化的 Numba
219257
@@ -307,39 +345,25 @@ with qe.Timer():
307345
308346### JAX 加 vmap
309347
310- NumPy 代码和上述 JAX 代码都存在一个问题:
311-
312- 虽然扁平数组占用内存较少
313-
314- ``` {code-cell} ipython3
315- grid.nbytes
316- ```
317-
318- 但网格矩阵的内存占用很大
348+ 由于我们在上面使用了 ` jax.jit ` ,我们避免了创建许多中间数组。
319349
320- ``` {code-cell} ipython3
321- x_mesh.nbytes + y_mesh.nbytes
322- ```
350+ 但我们仍然创建了大数组 ` z_max ` 、` x_mesh ` 和 ` y_mesh ` 。
323351
324- 在实际研究计算中,这种额外的内存使用可能是一个大问题。
325-
326- 幸运的是,JAX 提供了一种使用 [ jax.vmap] ( https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html ) 的不同方法。
327-
328- ` vmap ` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。
352+ 幸运的是,我们可以通过使用 [ jax.vmap] ( https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html ) 来避免这一问题。
329353
330354以下是我们将其应用于当前问题的方式。
331355
332356``` {code-cell} ipython3
333357@jax.jit
334358def compute_max_vmap(grid):
335359 # 构建一个对给定 y,在所有 x 上取最大值的函数
336- f_vec_x_max = lambda y: jnp.max(f(grid, y))
360+ compute_column_max = lambda y: jnp.max(f(grid, y))
337361 # 向量化该函数,以便我们可以同时对所有 y 调用
338- f_vec_max = jax.vmap(f_vec_x_max )
339- # 在每个 y 处计算所有 x 上的最大值
340- maxes = f_vec_max (grid)
341- # 计算最大值的最大值并返回
342- return jnp.max(maxes )
362+ vectorized_compute_column_max = jax.vmap(compute_column_max )
363+ # 在每一行处计算列最大值
364+ column_maxes = vectorized_compute_column_max (grid)
365+ # 计算列最大值的最大值并返回
366+ return jnp.max(column_maxes )
343367```
344368
345369注意我们从不创建
@@ -348,6 +372,8 @@ def compute_max_vmap(grid):
348372* 二维网格 ` y_mesh ` 或
349373* 二维数组 ` f(x,y) `
350374
375+ 与 Numba 类似,我们只使用扁平数组 ` grid ` 。
376+
351377并且由于所有内容都在单个 ` @jax.jit ` 下,编译器可以将所有操作融合为一个优化的内核。
352378
353379让我们试试。
@@ -378,13 +404,11 @@ with qe.Timer():
378404
379405它在速度(通过 JIT 编译和并行化)和内存效率(通过 vmap)两方面都优于 NumPy。
380406
381- 此外, ` vmap ` 方法有时可以带来更清晰的代码 。
407+ 在 GPU 上运行时,它也优于 Numba 。
382408
383- 虽然 Numba 令人印象深刻,但 JAX 的优势在于,对于完全向量化的运算,我们可以在配备硬件加速器的机器上运行完全相同的代码,并在无需额外努力的情况下获得所有收益。
384-
385- 此外,JAX 已经知道如何有效地并行化许多常见的数组运算,这是快速执行的关键。
386-
387- 对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
409+ ``` {note}
410+ Numba 可以通过 `numba.cuda` 支持 GPU 编程,但这样我们需要手动进行并行化。对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
411+ ```
388412
389413## 顺序运算
390414
@@ -536,8 +560,6 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
536560
537561虽然 JAX 的 ` at[t].set ` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。
538562
539- 对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
540-
541563## 总体建议
542564
543565让我们退一步,总结一下各方的权衡取舍。
@@ -550,17 +572,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
550572
551573此外,JAX 函数支持自动微分,我们将在 {doc}` autodiff ` 中进一步探讨。
552574
553- 对于** 顺序操作** ,Numba 具有明显优势 。
575+ 对于** 顺序操作** ,Numba 具有更简洁的语法 。
554576
555577代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。
556578
557- JAX 可以通过 ` lax.scan ` 处理顺序问题,但语法不够直观。
558-
559- ``` {note}
560- `lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
561- 如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
562- ```
579+ JAX 可以通过 ` lax.fori_loop ` 或 ` lax.scan ` 处理顺序问题,但语法不够直观。
563580
564- 在实践中,许多问题涉及两种模式的混合 。
581+ 另一方面,JAX 版本支持自动微分 。
565582
566- 一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba 。
583+ 例如,当我们希望计算轨迹对模型参数的敏感性时,这可能会很有用 。
0 commit comments