@@ -21,6 +21,8 @@ translation:
2121 Sequential operations : 顺序运算
2222 Sequential operations::Numba Version : Numba 版本
2323 Sequential operations::JAX Version : JAX 版本
24+ Sequential operations::JAX Version::First Attempt : 第一种尝试
25+ Sequential operations::JAX Version::Second Attempt : 第二种尝试
2426 Sequential operations::Summary : 总结
2527 Overall recommendations : 总体建议
2628---
@@ -143,33 +145,34 @@ m = -np.inf
143145for x in grid:
144146 for y in grid:
145147 z = f(x, y)
146- if z > m:
147- m = z
148+ m = max(m, z)
148149```
149150
150151### NumPy 向量化
151152
152- 如果我们切换到 NumPy 风格的向量化,就可以使用更大的网格,并且代码执行速度相对较快 。
153+ 让我们切换到 NumPy 并使用更大的网格 。
153154
154155这里我们使用 ` np.meshgrid ` 来创建二维输入网格 ` x ` 和 ` y ` ,使得 ` f(x, y) ` 能生成乘积网格上的所有计算结果。
155156
156- (这一策略可以追溯到 MATLAB。)
157-
158157``` {code-cell} ipython3
158+ # Large grid
159159grid = np.linspace(-3, 3, 3_000)
160- x, y = np.meshgrid(grid, grid)
160+
161+ x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
161162
162163with qe.Timer():
163164 z_max_numpy = np.max(f(x, y))
164-
165- print(f"NumPy result: {z_max_numpy:.6f}")
166165```
167166
168167在向量化版本中,所有循环都在编译后的代码中执行。
169168
170- 此外,NumPy 使用隐式多线程,因此至少会发生一定程度的并行化。
169+ 使用 ` meshgrid ` 可以复现嵌套的 for 循环。
170+
171+ 输出结果应接近于 1:
171172
172- (并行化效率不高,因为二进制文件在看到数组 ` x ` 和 ` y ` 的大小之前就已经被编译了。)
173+ ``` {code-cell} ipython3
174+ print(f"NumPy result: {z_max_numpy:.6f}")
175+ ```
173176
174177### 与 Numba 的比较
175178
@@ -194,8 +197,6 @@ grid = np.linspace(-3, 3, 3_000)
194197with qe.Timer():
195198 # First run
196199 z_max_numba = compute_max_numba(grid)
197-
198- print(f"Numba result: {z_max_numba:.6f}")
199200```
200201
201202让我们再次运行以消除编译时间。
@@ -238,8 +239,6 @@ def compute_max_numba_parallel(grid):
238239with qe.Timer():
239240 # First run
240241 z_max_parallel = compute_max_numba_parallel(grid)
241-
242- print(f"Numba result: {z_max_parallel:.6f}")
243242```
244243
245244以下是预编译版本的计时结果。
@@ -250,15 +249,19 @@ with qe.Timer():
250249 compute_max_numba_parallel(grid)
251250```
252251
253- 如果您有多个核心,您应该能在此处看到并行化带来的一定收益 。
252+ 如果您有多个核心,您应该能在此处看到并行化带来的收益 。
254253
255- 对于更强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来显著的速度提升。
254+ 让我们确认结果仍然正确(接近于 1):
256255
257- ### 使用 JAX 的向量化代码
256+ ``` {code-cell} ipython3
257+ print(f"Numba result: {z_max_parallel:.6f}")
258+ ```
258259
259- 表面上,JAX 中的向量化代码与 NumPy 代码类似 。
260+ 对于强大的机器和更大的网格尺寸,即使在 CPU 上,并行化也能带来有用的速度提升 。
260261
261- 但两者之间也存在一些差异,我们在这里加以强调。
262+ ### 使用 JAX 的向量化代码
263+
264+ 让我们尝试用 JAX 复现 NumPy 的向量化方法。
262265
263266让我们从函数开始,将 ` np ` 替换为 ` jnp ` 并添加 ` jax.jit `
264267
@@ -269,7 +272,7 @@ def f(x, y):
269272
270273```
271274
272- 与 NumPy 一样,为了获得正确的形状和正确的嵌套 ` for ` 循环计算,我们可以使用专为此目的设计的 ` meshgrid ` 操作 :
275+ 我们使用 NumPy 风格的 meshgrid 方法 :
273276
274277``` {code-cell} ipython3
275278grid = jnp.linspace(-3, 3, 3_000)
@@ -326,60 +329,24 @@ x_mesh.nbytes + y_mesh.nbytes
326329
327330以下是我们将其应用于当前问题的方式。
328331
329- ``` {code-cell} ipython3
330- # 设置 f,使其在给定任意 y 时,对所有 x 计算 f(x, y)
331- f_vec_x = lambda y: f(grid, y)
332- # 创建第二个函数,将此操作在所有 y 上向量化
333- f_vec = jax.vmap(f_vec_x)
334- ```
335-
336- 现在,当以扁平数组 ` grid ` 调用时,` f_vec ` 将在每个 ` x,y ` 处计算 ` f(x,y) ` 。
337-
338- 让我们看看计时结果:
339-
340- ``` {code-cell} ipython3
341- with qe.Timer():
342- z_max = jnp.max(f_vec(grid))
343- z_max.block_until_ready()
344-
345- print(f"JAX vmap v1 result: {z_max:.6f}")
346- ```
347-
348- ``` {code-cell} ipython3
349- with qe.Timer():
350- z_max = jnp.max(f_vec(grid))
351- z_max.block_until_ready()
352- ```
353-
354- 通过避免使用大型输入数组 ` x_mesh ` 和 ` y_mesh ` ,这个 ` vmap ` 版本使用的内存少得多,运行时间变化不大。
355-
356- 这很好——但我们还有进一步提升速度的空间!
357-
358- 首先请注意,上面的代码计算了完整的二维数组 ` f(x,y) ` ,这会产生开销,然后再取最大值。
359-
360- 其次,` jnp.max ` 调用位于 JIT 编译函数 ` f ` 之外,因此编译器无法将这些操作融合为单个内核。
361-
362- 我们可以通过将最大值操作移到内部并将所有内容包装在一个 ` @jax.jit ` 中来解决这两个问题:
363-
364332``` {code-cell} ipython3
365333@jax.jit
366334def compute_max_vmap(grid):
367- # 构建一个沿每行取最大值的函数
335+ # 构建一个对给定 y,在所有 x 上取最大值的函数
368336 f_vec_x_max = lambda y: jnp.max(f(grid, y))
369- # 向量化该函数,以便我们可以同时对所有行调用
337+ # 向量化该函数,以便我们可以同时对所有 y 调用
370338 f_vec_max = jax.vmap(f_vec_x_max)
371- # 调用向量化函数并取最大值
372- return jnp.max(f_vec_max(grid))
339+ # 在每个 y 处计算所有 x 上的最大值
340+ maxes = f_vec_max(grid)
341+ # 计算最大值的最大值并返回
342+ return jnp.max(maxes)
373343```
374344
375- 其中
376-
377- * ` f_vec_x_max ` 计算任意给定行的最大值
378- * ` f_vec_max ` 是一个向量化版本,可以并行计算所有行的最大值。
379-
380- 我们将此函数应用于所有行,然后取各行最大值中的最大值。
345+ 注意我们从不创建
381346
382- 由于将最大值操作移到内部,我们永远不会构建完整的二维数组 ` f(x,y) ` ,从而节省了更多内存。
347+ * 二维网格 ` x_mesh `
348+ * 二维网格 ` y_mesh ` 或
349+ * 二维数组 ` f(x,y) `
383350
384351并且由于所有内容都在单个 ` @jax.jit ` 下,编译器可以将所有操作融合为一个优化的内核。
385352
@@ -461,21 +428,70 @@ with qe.Timer():
461428
462429Numba 非常高效地处理了这个顺序运算。
463430
464- 注意,JIT 编译完成后,第二次运行明显更快。
431+ ### JAX 版本
465432
466- Numba 的编译通常相当快,对于像这样的顺序运算,生成的代码性能非常出色 。
433+ 我们不能直接用 ` jax.jit ` 替换 ` numba.jit ` ,因为 JAX 数组是不可变的 。
467434
468- ### JAX 版本
435+ 但我们仍然可以实现这一运算。
436+
437+ #### 第一种尝试
469438
470- 现在让我们使用 ` lax.scan ` 创建一个 JAX 版本:
439+ 以下是使用 ` at[t].set ` 语法的变通方案,我们在 {ref} ` JAX 讲座中讨论过 <jax_at_workaround> ` 。
471440
472- (我们将 ` n ` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。)
441+ 我们将应用 ` lax.fori_loop ` ,这是一种可以被 XLA 编译的 for 循环版本。
473442
474443``` {code-cell} ipython3
475444cpu = jax.devices("cpu")[0]
476445
477- @partial(jax.jit, static_argnames=('n',), device=cpu)
478- def qm_jax(x0, n, α=4.0):
446+ @partial(jax.jit, static_argnames=("n",), device=cpu)
447+ def qm_jax_fori(x0, n, α=4.0):
448+
449+ x = jnp.empty(n + 1).at[0].set(x0)
450+
451+ def update(t, x):
452+ return x.at[t + 1].set(α * x[t] * (1 - x[t]))
453+
454+ x = lax.fori_loop(0, n, update, x)
455+ return x
456+
457+ ```
458+
459+ * 我们将 ` n ` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。
460+ * 我们通过 ` device=cpu ` 将计算固定到 CPU,因为这种顺序工作负载由许多小型运算组成,几乎没有机会利用 GPU 并行性。
461+
462+ 重要提示:虽然 ` at[t].set ` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新!
463+
464+ 让我们使用相同的参数计时:
465+
466+ ``` {code-cell} ipython3
467+ with qe.Timer():
468+ # First run
469+ x_jax = qm_jax_fori(0.1, n)
470+ # Hold interpreter
471+ x_jax.block_until_ready()
472+ ```
473+
474+ 让我们再次运行以消除编译开销:
475+
476+ ``` {code-cell} ipython3
477+ with qe.Timer():
478+ # Second run
479+ x_jax = qm_jax_fori(0.1, n)
480+ # Hold interpreter
481+ x_jax.block_until_ready()
482+ ```
483+
484+ JAX 对于这种顺序运算也相当高效!
485+
486+ #### 第二种尝试
487+
488+ 还有另一种使用 ` lax.scan ` 实现该循环的方式。
489+
490+ 这种替代方案可以说更符合 JAX 的函数式风格——尽管语法难以记忆。
491+
492+ ``` {code-cell} ipython3
493+ @partial(jax.jit, static_argnames=("n",), device=cpu)
494+ def qm_jax_scan(x0, n, α=4.0):
479495 def update(x, t):
480496 x_new = α * x * (1 - x)
481497 return x_new, x_new
@@ -486,16 +502,12 @@ def qm_jax(x0, n, α=4.0):
486502
487503这段代码不易阅读,但本质上,` lax.scan ` 反复调用 ` update ` 并将返回值 ` x_new ` 累积到一个数组中。
488504
489- ``` {note}
490- 我们在 `jax.jit` 装饰器中指定了 `device=cpu`,因为该计算由许多小的顺序运算组成,几乎没有机会让 GPU 利用并行性。因此,GPU 上的内核启动开销往往占主导地位,使得 CPU 更适合这种工作负载。
491- ```
492-
493505让我们使用相同的参数计时:
494506
495507``` {code-cell} ipython3
496508with qe.Timer():
497509 # First run
498- x_jax = qm_jax (0.1, n)
510+ x_jax = qm_jax_scan (0.1, n)
499511 # Hold interpreter
500512 x_jax.block_until_ready()
501513```
@@ -505,26 +517,24 @@ with qe.Timer():
505517``` {code-cell} ipython3
506518with qe.Timer():
507519 # Second run
508- x_jax = qm_jax (0.1, n)
520+ x_jax = qm_jax_scan (0.1, n)
509521 # Hold interpreter
510522 x_jax.block_until_ready()
511523```
512524
513- JAX 对于这种顺序运算也相当高效。
514-
515- JAX 和 Numba 在编译后都能提供出色的性能。
525+ 令人惊讶的是,JAX 在编译后也能提供出色的性能。
516526
517527### 总结
518528
519- 虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,* 在代码可读性和易用性方面存在显著差异 * 。
529+ 虽然 Numba 和 JAX 在顺序运算中都能提供出色的性能,但在代码可读性和易用性方面存在差异 。
520530
521531Numba 版本简单直观,易于阅读:我们只需分配一个数组,然后使用标准 Python 循环逐元素填充它。
522532
523533这正是大多数程序员思考该算法的方式。
524534
525- 另一方面,JAX 版本需要使用 ` lax.scan ` ,这明显不够直观 。
535+ 另一方面,JAX 版本需要使用 ` lax.fori_loop ` 或 ` lax. scan` ,这两者都不如标准 Python 循环直观 。
526536
527- 此外, JAX 的不可变数组意味着我们无法简单地就地更新数组元素,这使得直接复制 Numba 使用的算法变得困难 。
537+ 虽然 JAX 的 ` at[t].set ` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读 。
528538
529539对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。
530540
0 commit comments