Skip to content

Commit 2d0d754

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent 70d05ea commit 2d0d754

1 file changed

Lines changed: 95 additions & 85 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 95 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ translation:
2121
Sequential operations: 顺序运算
2222
Sequential operations::Numba Version: Numba 版本
2323
Sequential operations::JAX Version: JAX 版本
24+
Sequential operations::JAX Version::First Attempt: 第一种尝试
25+
Sequential operations::JAX Version::Second Attempt: 第二种尝试
2426
Sequential operations::Summary: 总结
2527
Overall recommendations: 总体建议
2628
---
@@ -143,33 +145,34 @@ m = -np.inf
143145
for x in grid:
144146
for y in grid:
145147
z = f(x, y)
146-
if z > m:
147-
m = z
148+
m = max(m, z)
148149
```
149150

150151
### NumPy 向量化
151152

152-
如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快
153+
让我们切换到 NumPy 并使用更大的网格
153154

154155
这里我们使用 `np.meshgrid` 来创建二维输入网格 `x``y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。
155156

156-
(这一策略可以追溯到 MATLAB。)
157-
158157
```{code-cell} ipython3
158+
# Large grid
159159
grid = np.linspace(-3, 3, 3_000)
160-
x, y = np.meshgrid(grid, grid)
160+
161+
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
161162
162163
with qe.Timer():
163164
z_max_numpy = np.max(f(x, y))
164-
165-
print(f"NumPy result: {z_max_numpy:.6f}")
166165
```
167166

168167
在向量化版本中,所有循环都在编译后的代码中执行。
169168

170-
此外,NumPy 使用隐式多线程,因此至少会发生一定程度的并行化。
169+
使用 `meshgrid` 可以复现嵌套的 for 循环。
170+
171+
输出结果应接近于 1:
171172

172-
(并行化效率不高,因为二进制文件在看到数组 `x``y` 的大小之前就已经被编译了。)
173+
```{code-cell} ipython3
174+
print(f"NumPy result: {z_max_numpy:.6f}")
175+
```
173176

174177
### 与 Numba 的比较
175178

@@ -194,8 +197,6 @@ grid = np.linspace(-3, 3, 3_000)
194197
with qe.Timer():
195198
# First run
196199
z_max_numba = compute_max_numba(grid)
197-
198-
print(f"Numba result: {z_max_numba:.6f}")
199200
```
200201

201202
让我们再次运行以消除编译时间。
@@ -238,8 +239,6 @@ def compute_max_numba_parallel(grid):
238239
with qe.Timer():
239240
# First run
240241
z_max_parallel = compute_max_numba_parallel(grid)
241-
242-
print(f"Numba result: {z_max_parallel:.6f}")
243242
```
244243

245244
以下是预编译版本的计时结果。
@@ -250,15 +249,19 @@ with qe.Timer():
250249
compute_max_numba_parallel(grid)
251250
```
252251

253-
如果您有多个核心,您应该能在此处看到并行化带来的一定收益
252+
如果您有多个核心,您应该能在此处看到并行化带来的收益
254253

255-
对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。
254+
让我们确认结果仍然正确(接近于 1):
256255

257-
### 使用 JAX 的向量化代码
256+
```{code-cell} ipython3
257+
print(f"Numba result: {z_max_parallel:.6f}")
258+
```
258259

259-
表面上,JAX 中的向量化代码与 NumPy 代码类似
260+
对于强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来有用的速度提升
260261

261-
但两者之间也存在一些差异,我们在这里加以强调。
262+
### 使用 JAX 的向量化代码
263+
264+
让我们尝试用 JAX 复现 NumPy 的向量化方法。
262265

263266
让我们从函数开始,将 `np` 替换为 `jnp` 并添加 `jax.jit`
264267

@@ -269,7 +272,7 @@ def f(x, y):
269272
270273
```
271274

272-
NumPy 一样,为了获得正确的形状和正确的嵌套 `for` 循环计算,我们可以使用专为此目的设计的 `meshgrid` 操作
275+
我们使用 NumPy 风格的 meshgrid 方法
273276

274277
```{code-cell} ipython3
275278
grid = jnp.linspace(-3, 3, 3_000)
@@ -326,60 +329,24 @@ x_mesh.nbytes + y_mesh.nbytes
326329

327330
以下是我们将其应用于当前问题的方式。
328331

329-
```{code-cell} ipython3
330-
# 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y)
331-
f_vec_x = lambda y: f(grid, y)
332-
# 创建第二个函数,将此操作在所有 y 上向量化
333-
f_vec = jax.vmap(f_vec_x)
334-
```
335-
336-
现在,当以扁平数组 `grid` 调用时,`f_vec` 将在每个 `x,y` 处计算 `f(x,y)`
337-
338-
让我们看看计时结果:
339-
340-
```{code-cell} ipython3
341-
with qe.Timer():
342-
z_max = jnp.max(f_vec(grid))
343-
z_max.block_until_ready()
344-
345-
print(f"JAX vmap v1 result: {z_max:.6f}")
346-
```
347-
348-
```{code-cell} ipython3
349-
with qe.Timer():
350-
z_max = jnp.max(f_vec(grid))
351-
z_max.block_until_ready()
352-
```
353-
354-
通过避免使用大型输入数组 `x_mesh``y_mesh`,这个 `vmap` 版本使用的内存少得多,运行时间变化不大。
355-
356-
这很好——但我们还有进一步提升速度的空间!
357-
358-
首先请注意,上面的代码计算了完整的二维数组 `f(x,y)`,这会产生开销,然后再取最大值。
359-
360-
其次,`jnp.max` 调用位于 JIT 编译函数 `f` 之外,因此编译器无法将这些操作融合为单个内核。
361-
362-
我们可以通过将最大值操作移到内部并将所有内容包装在一个 `@jax.jit` 中来解决这两个问题:
363-
364332
```{code-cell} ipython3
365333
@jax.jit
366334
def compute_max_vmap(grid):
367-
# 构建一个沿每行取最大值的函数
335+
# 构建一个对给定 y,在所有 x 上取最大值的函数
368336
f_vec_x_max = lambda y: jnp.max(f(grid, y))
369-
# 向量化该函数,以便我们可以同时对所有行调用
337+
# 向量化该函数,以便我们可以同时对所有 y 调用
370338
f_vec_max = jax.vmap(f_vec_x_max)
371-
# 调用向量化函数并取最大值
372-
return jnp.max(f_vec_max(grid))
339+
# 在每个 y 处计算所有 x 上的最大值
340+
maxes = f_vec_max(grid)
341+
# 计算最大值的最大值并返回
342+
return jnp.max(maxes)
373343
```
374344

375-
其中
376-
377-
* `f_vec_x_max` 计算任意给定行的最大值
378-
* `f_vec_max` 是一个向量化版本,可以并行计算所有行的最大值。
379-
380-
我们将此函数应用于所有行,然后取各行最大值中的最大值。
345+
注意我们从不创建
381346

382-
由于将最大值操作移到内部,我们永远不会构建完整的二维数组 `f(x,y)`,从而节省了更多内存。
347+
* 二维网格 `x_mesh`
348+
* 二维网格 `y_mesh`
349+
* 二维数组 `f(x,y)`
383350

384351
并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。
385352

@@ -461,21 +428,70 @@ with qe.Timer():
461428

462429
Numba 非常高效地处理了这个顺序运算。
463430

464-
注意,JIT 编译完成后,第二次运行明显更快。
431+
### JAX 版本
465432

466-
Numba 的编译通常相当快,对于像这样的顺序运算,生成的代码性能非常出色
433+
我们不能直接用 `jax.jit` 替换 `numba.jit`,因为 JAX 数组是不可变的
467434

468-
### JAX 版本
435+
但我们仍然可以实现这一运算。
436+
437+
#### 第一种尝试
469438

470-
现在让我们使用 `lax.scan` 创建一个 JAX 版本:
439+
以下是使用 `at[t].set` 语法的变通方案,我们在 {ref}`JAX 讲座中讨论过 <jax_at_workaround>`
471440

472-
(我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
441+
我们将应用 `lax.fori_loop`,这是一种可以被 XLA 编译的 for 循环版本。
473442

474443
```{code-cell} ipython3
475444
cpu = jax.devices("cpu")[0]
476445
477-
@partial(jax.jit, static_argnames=('n',), device=cpu)
478-
def qm_jax(x0, n, α=4.0):
446+
@partial(jax.jit, static_argnames=("n",), device=cpu)
447+
def qm_jax_fori(x0, n, α=4.0):
448+
449+
x = jnp.empty(n + 1).at[0].set(x0)
450+
451+
def update(t, x):
452+
return x.at[t + 1].set(α * x[t] * (1 - x[t]))
453+
454+
x = lax.fori_loop(0, n, update, x)
455+
return x
456+
457+
```
458+
459+
* 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。
460+
* 我们通过 `device=cpu` 将计算固定到 CPU,因为这种顺序工作负载由许多小型运算组成,几乎没有机会利用 GPU 并行性。
461+
462+
重要提示:虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新!
463+
464+
让我们使用相同的参数计时:
465+
466+
```{code-cell} ipython3
467+
with qe.Timer():
468+
# First run
469+
x_jax = qm_jax_fori(0.1, n)
470+
# Hold interpreter
471+
x_jax.block_until_ready()
472+
```
473+
474+
让我们再次运行以消除编译开销:
475+
476+
```{code-cell} ipython3
477+
with qe.Timer():
478+
# Second run
479+
x_jax = qm_jax_fori(0.1, n)
480+
# Hold interpreter
481+
x_jax.block_until_ready()
482+
```
483+
484+
JAX 对于这种顺序运算也相当高效!
485+
486+
#### 第二种尝试
487+
488+
还有另一种使用 `lax.scan` 实现该循环的方式。
489+
490+
这种替代方案可以说更符合 JAX 的函数式风格——尽管语法难以记忆。
491+
492+
```{code-cell} ipython3
493+
@partial(jax.jit, static_argnames=("n",), device=cpu)
494+
def qm_jax_scan(x0, n, α=4.0):
479495
def update(x, t):
480496
x_new = α * x * (1 - x)
481497
return x_new, x_new
@@ -486,16 +502,12 @@ def qm_jax(x0, n, α=4.0):
486502

487503
这段代码不易阅读,但本质上,`lax.scan` 反复调用 `update` 并将返回值 `x_new` 累积到一个数组中。
488504

489-
```{note}
490-
我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。
491-
```
492-
493505
让我们使用相同的参数计时:
494506

495507
```{code-cell} ipython3
496508
with qe.Timer():
497509
# First run
498-
x_jax = qm_jax(0.1, n)
510+
x_jax = qm_jax_scan(0.1, n)
499511
# Hold interpreter
500512
x_jax.block_until_ready()
501513
```
@@ -505,26 +517,24 @@ with qe.Timer():
505517
```{code-cell} ipython3
506518
with qe.Timer():
507519
# Second run
508-
x_jax = qm_jax(0.1, n)
520+
x_jax = qm_jax_scan(0.1, n)
509521
# Hold interpreter
510522
x_jax.block_until_ready()
511523
```
512524

513-
JAX 对于这种顺序运算也相当高效。
514-
515-
JAX 和 Numba 在编译后都能提供出色的性能。
525+
令人惊讶的是,JAX 在编译后也能提供出色的性能。
516526

517527
### 总结
518528

519-
虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,*在代码可读性和易用性方面存在显著差异*
529+
虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但在代码可读性和易用性方面存在差异
520530

521531
Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。
522532

523533
这正是大多数程序员思考该算法的方式。
524534

525-
另一方面,JAX 版本需要使用 `lax.scan`这明显不够直观
535+
另一方面,JAX 版本需要使用 `lax.fori_loop``lax.scan`这两者都不如标准 Python 循环直观
526536

527-
此外,JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难
537+
虽然 JAX `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读
528538

529539
对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
530540

0 commit comments

Comments
 (0)