Skip to content

Commit 8233c0e

Browse files
🌐 [translation-sync] Misc changes to jax lectures (#66)
* Update translation: lectures/jax_intro.md * Update translation: .translate/state/jax_intro.md.yml * Update translation: lectures/numpy_vs_numba_vs_jax.md * Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e15b2ed commit 8233c0e

4 files changed

Lines changed: 102 additions & 66 deletions

File tree

.translate/state/jax_intro.md.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2-
synced-at: "2026-04-14"
1+
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
2+
synced-at: "2026-05-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 7
6-
tool-version: 0.14.1
6+
tool-version: 0.15.0
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2-
synced-at: "2026-04-14"
1+
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
2+
synced-at: "2026-05-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3
6-
tool-version: 0.14.1
6+
tool-version: 0.15.0

lectures/jax_intro.md

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ translation:
2424
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
2525
Functional Programming: 函数式编程
2626
Functional Programming::Pure functions: 纯函数
27-
Functional Programming::Examples: 示例
27+
Functional Programming::Examples -- Pure and Impure: 示例——纯函数与非纯函数
2828
Functional Programming::Why Functional Programming?: 为什么要函数式编程?
2929
Random numbers: 随机数
3030
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
@@ -346,19 +346,20 @@ a
346346
* 不会改变全局状态
347347
* 不会修改传递给函数的数据(不可变数据)
348348

349-
### 示例
349+
### 示例——纯函数与非纯函数
350350

351351
以下是一个*非纯*函数的示例:
352352

353353
```{code-cell} ipython3
354354
tax_rate = 0.1
355-
prices = [10.0, 20.0]
356355
357356
def add_tax(prices):
358357
for i, price in enumerate(prices):
359358
prices[i] = price * (1 + tax_rate)
360-
print('Post-tax prices: ', prices)
361-
return prices
359+
360+
prices = [10.0, 20.0]
361+
add_tax(prices)
362+
prices
362363
```
363364

364365
这个函数不是纯函数,因为:
@@ -369,15 +370,21 @@ def add_tax(prices):
369370
以下是一个**版本:
370371

371372
```{code-cell} ipython3
372-
tax_rate = 0.1
373-
prices = (10.0, 20.0)
374373
375374
def add_tax_pure(prices, tax_rate):
376375
new_prices = [price * (1 + tax_rate) for price in prices]
377376
return new_prices
377+
378+
tax_rate = 0.1
379+
prices = (10.0, 20.0)
380+
after_tax_prices = add_tax_pure(prices, tax_rate)
381+
after_tax_prices
378382
```
379383

380-
这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。
384+
这是纯函数,因为:
385+
386+
* 所有依赖关系通过函数参数显式传递
387+
* 并且不修改任何外部状态
381388

382389
### 为什么要函数式编程?
383390

@@ -427,7 +434,7 @@ print(np.random.randn(2))
427434
* 它是非确定性的:相同的输入,不同的输出
428435
* 它有副作用:它修改了全局随机数生成器状态
429436

430-
在并行化下很危险——必须仔细控制每个线程中发生的事情
437+
在并行化下很危险——必须仔细控制每个线程中发生的事情
431438

432439
### JAX
433440

@@ -544,7 +551,11 @@ plt.show()
544551
下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。
545552

546553
```{code-cell} ipython3
547-
def gen_random_matrices(key, n=2, k=3):
554+
def gen_random_matrices(
555+
key, # JAX key for random numbers
556+
n=2, # Matrices will be n x n
557+
k=3 # Number of matrices to generate
558+
):
548559
matrices = []
549560
for _ in range(k):
550561
key, subkey = jax.random.split(key)
@@ -566,7 +577,7 @@ gen_random_matrices(key)
566577

567578
### 好处
568579

569-
JAX 的显式性带来了显著的好处
580+
如上所述,这种显式性是很有价值的
570581

571582
* 可复现性:通过重用密钥轻松重现结果
572583
* 并行化:控制各个线程上发生的事情
@@ -647,7 +658,14 @@ with qe.Timer():
647658

648659
结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
649660

650-
但我们仍在使用即时执行——大量内存和读写开销。
661+
这是因为单个数组操作在 GPU 上并行化了。
662+
663+
但我们仍在使用即时执行:
664+
665+
* 由于中间数组导致大量内存占用
666+
* 大量内存读写
667+
668+
此外,GPU 上还会启动许多独立的内核。
651669

652670
### 编译整个函数
653671

@@ -681,7 +699,8 @@ with qe.Timer():
681699

682700
* 基于整个计算序列的积极优化
683701
* 消除对硬件加速器的多次调用
684-
* 不创建中间数组
702+
703+
内存占用也大大降低——不再创建中间数组。
685704

686705
顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
687706

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ translation:
1313
Vectorized operations: 向量化运算
1414
Vectorized operations::Problem Statement: 问题陈述
1515
Vectorized operations::NumPy vectorization: NumPy 向量化
16+
Vectorized operations::Memory Issues: 内存问题
1617
Vectorized operations::A Comparison with Numba: 与 Numba 的比较
1718
Vectorized operations::Parallelized Numba: 并行化的 Numba
1819
Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码
@@ -152,16 +153,33 @@ for x in grid:
152153

153154
让我们切换到 NumPy 并使用更大的网格。
154155

156+
```{code-cell} ipython3
157+
grid = np.linspace(-3, 3, 3_000) # Large grid
158+
```
159+
160+
作为向量化的第一步,我们可能会尝试这样的方式
161+
162+
```{code-cell} ipython3
163+
# Large grid
164+
z = np.max(f(grid, grid)) # This is wrong!
165+
```
166+
167+
这里的问题是 `f(grid, grid)` 并不遵循嵌套循环。
168+
169+
从上图来看,它只计算了对角线上的 `f` 值。
170+
171+
要让 NumPy 在每个 `x,y` 对上计算 `f(x,y)`,我们需要使用 `np.meshgrid`
172+
155173
这里我们使用 `np.meshgrid` 来创建二维输入网格 `x``y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。
156174

157175
```{code-cell} ipython3
158176
# Large grid
159177
grid = np.linspace(-3, 3, 3_000)
160178
161-
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
179+
x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid
162180
163181
with qe.Timer():
164-
z_max_numpy = np.max(f(x, y))
182+
z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
165183
```
166184

167185
在向量化版本中,所有循环都在编译后的代码中执行。
@@ -174,9 +192,29 @@ with qe.Timer():
174192
print(f"NumPy result: {z_max_numpy:.6f}")
175193
```
176194

195+
### 内存问题
196+
197+
我们在合理的时间内得到了正确的解——但内存使用量非常大。
198+
199+
虽然扁平数组占用内存较少
200+
201+
```{code-cell} ipython3
202+
grid.nbytes
203+
```
204+
205+
但网格矩阵是二维的,因此内存占用非常大
206+
207+
```{code-cell} ipython3
208+
x_mesh.nbytes + y_mesh.nbytes
209+
```
210+
211+
此外,NumPy 的即时执行会创建许多相同大小的中间数组!
212+
213+
在实际研究计算中,这种内存使用可能是一个大问题。
214+
177215
### 与 Numba 的比较
178216

179-
现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
217+
让我们看看能否使用简单循环的 Numba 获得更好的性能。
180218

181219
```{code-cell} ipython3
182220
@numba.jit
@@ -207,13 +245,13 @@ with qe.Timer():
207245
compute_max_numba(grid)
208246
```
209247

210-
根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快
248+
注意我们几乎不使用任何内存——我们只需要一维的 `grid`
211249

212-
在大多数情况下,我们发现 Numba 略胜一筹
250+
此外,执行速度也很好
213251

214-
一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势
252+
在大多数机器上,Numba 版本会比 NumPy 稍快一些
215253

216-
另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格
254+
原因是高效的机器码加上更少的内存读写
217255

218256
### 并行化的 Numba
219257

@@ -307,39 +345,25 @@ with qe.Timer():
307345

308346
### JAX 加 vmap
309347

310-
NumPy 代码和上述 JAX 代码都存在一个问题:
311-
312-
虽然扁平数组占用内存较少
313-
314-
```{code-cell} ipython3
315-
grid.nbytes
316-
```
317-
318-
但网格矩阵的内存占用很大
348+
由于我们在上面使用了 `jax.jit`,我们避免了创建许多中间数组。
319349

320-
```{code-cell} ipython3
321-
x_mesh.nbytes + y_mesh.nbytes
322-
```
350+
但我们仍然创建了大数组 `z_max``x_mesh``y_mesh`
323351

324-
在实际研究计算中,这种额外的内存使用可能是一个大问题。
325-
326-
幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。
327-
328-
`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。
352+
幸运的是,我们可以通过使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 来避免这一问题。
329353

330354
以下是我们将其应用于当前问题的方式。
331355

332356
```{code-cell} ipython3
333357
@jax.jit
334358
def compute_max_vmap(grid):
335359
# 构建一个对给定 y,在所有 x 上取最大值的函数
336-
f_vec_x_max = lambda y: jnp.max(f(grid, y))
360+
compute_column_max = lambda y: jnp.max(f(grid, y))
337361
# 向量化该函数,以便我们可以同时对所有 y 调用
338-
f_vec_max = jax.vmap(f_vec_x_max)
339-
# 在每个 y 处计算所有 x 上的最大值
340-
maxes = f_vec_max(grid)
341-
# 计算最大值的最大值并返回
342-
return jnp.max(maxes)
362+
vectorized_compute_column_max = jax.vmap(compute_column_max)
363+
# 在每一行处计算列最大值
364+
column_maxes = vectorized_compute_column_max(grid)
365+
# 计算列最大值的最大值并返回
366+
return jnp.max(column_maxes)
343367
```
344368

345369
注意我们从不创建
@@ -348,6 +372,8 @@ def compute_max_vmap(grid):
348372
* 二维网格 `y_mesh`
349373
* 二维数组 `f(x,y)`
350374

375+
与 Numba 类似,我们只使用扁平数组 `grid`
376+
351377
并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。
352378

353379
让我们试试。
@@ -378,13 +404,11 @@ with qe.Timer():
378404

379405
它在速度(通过 JIT 编译和并行化)和内存效率(通过 vmap)两方面都优于 NumPy。
380406

381-
此外,`vmap` 方法有时可以带来更清晰的代码
407+
在 GPU 上运行时,它也优于 Numba
382408

383-
虽然 Numba 令人印象深刻,但 JAX 的优势在于,对于完全向量化的运算,我们可以在配备硬件加速器的机器上运行完全相同的代码,并在无需额外努力的情况下获得所有收益。
384-
385-
此外,JAX 已经知道如何有效地并行化许多常见的数组运算,这是快速执行的关键。
386-
387-
对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
409+
```{note}
410+
Numba 可以通过 `numba.cuda` 支持 GPU 编程,但这样我们需要手动进行并行化。对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
411+
```
388412

389413
## 顺序运算
390414

@@ -536,8 +560,6 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
536560

537561
虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。
538562

539-
对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
540-
541563
## 总体建议
542564

543565
让我们退一步,总结一下各方的权衡取舍。
@@ -550,17 +572,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然
550572

551573
此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进一步探讨。
552574

553-
对于**顺序操作**,Numba 具有明显优势
575+
对于**顺序操作**,Numba 具有更简洁的语法
554576

555577
代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。
556578

557-
JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观。
558-
559-
```{note}
560-
`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
561-
如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
562-
```
579+
JAX 可以通过 `lax.fori_loop``lax.scan` 处理顺序问题,但语法不够直观。
563580

564-
在实践中,许多问题涉及两种模式的混合
581+
另一方面,JAX 版本支持自动微分
565582

566-
一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba
583+
例如,当我们希望计算轨迹对模型参数的敏感性时,这可能会很有用

0 commit comments

Comments
 (0)