Skip to content

Commit 86a5c13

Browse files
authored
🌐 [translation-sync] Improve and simplify JAX lectures (#57)
* Update translation: lectures/jax_intro.md * Update translation: .translate/state/jax_intro.md.yml * Update translation: lectures/numpy_vs_numba_vs_jax.md * Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml * Add missing (jax_at_workaround)= anchor in jax_intro.md The cross-reference {ref}`jax_at_workaround` in numpy_vs_numba_vs_jax.md exists in upstream but the anchor was dropped during translation. Restoring it so the Jupyter Book build resolves the reference. * Disambiguate 急切 (eager) vs 即时 (JIT) execution Per Copilot translation review: 即时 is used elsewhere in this file for JIT compilation, so using 即时 again for 'eager execution' was ambiguous. Use 急切 (Eager) with English parenthetical for clarity.
1 parent 6f36a65 commit 86a5c13

4 files changed

Lines changed: 169 additions & 203 deletions

File tree

.translate/state/jax_intro.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
2-
synced-at: "2026-04-13"
1+
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2+
synced-at: "2026-04-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 7

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f
2-
synced-at: "2026-04-13"
1+
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2+
synced-at: "2026-04-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3

lectures/jax_intro.md

Lines changed: 70 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,24 @@ translation:
2121
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2222
JAX as a NumPy Replacement::Differences::Precision: 精度
2323
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
24-
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
24+
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
2525
Functional Programming: 函数式编程
2626
Functional Programming::Pure functions: 纯函数
2727
Functional Programming::Examples: 示例
28-
Functional Programming::Why Functional Programming?: 为什么使用函数式编程
28+
Functional Programming::Why Functional Programming?: 为什么要函数式编程
2929
Random numbers: 随机数
30-
Random numbers::Random number generation: 随机数生成
31-
Random numbers::Why explicit random state?: 为什么要显式随机状态?
32-
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
33-
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
30+
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
31+
Random numbers::JAX: JAX
32+
Random numbers::Benefits: 好处
3433
JIT Compilation: JIT 编译
3534
JIT Compilation::With NumPy: 使用 NumPy
3635
JIT Compilation::With JAX: 使用 JAX
3736
JIT Compilation::Compiling the Whole Function: 编译整个函数
3837
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
3938
JIT Compilation::Compiling non-pure functions: 编译非纯函数
39+
Vectorization with `vmap`: 使用 `vmap` 进行向量化
40+
Vectorization with `vmap`::A simple example: 一个简单的示例
41+
Vectorization with `vmap`::Combining transformations: 组合变换
4042
Exercises: 练习
4143
---
4244

@@ -290,6 +292,7 @@ except Exception as e:
290292

291293
JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。
292294

295+
(jax_at_workaround)=
293296
#### 变通方法
294297

295298
我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 [`at` 方法](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html)
@@ -404,13 +407,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射
404407

405408
JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。
406409

407-
起初,您可能会觉得语法相当冗长。
410+
### NumPy / MATLAB 方法
408411

409-
但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的
412+
在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作
410413

411-
此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
414+
```{code-cell} ipython3
415+
np.random.seed(42)
416+
print(np.random.randn(2))
417+
```
418+
419+
每次我们调用随机函数时,隐藏状态都会被更新:
420+
421+
```{code-cell} ipython3
422+
print(np.random.randn(2))
423+
```
424+
425+
这个函数*不是纯函数*,因为:
412426

413-
### 随机数生成
427+
* 它是非确定性的:相同的输入,不同的输出
428+
* 它有副作用:它修改了全局随机数生成器状态
429+
430+
在并行化下很危险——必须仔细控制每个线程中发生的事情!
431+
432+
### JAX
414433

415434
在 JAX 中,随机数生成器的状态被显式控制。
416435

@@ -531,119 +550,48 @@ def gen_random_matrices(key, n=2, k=3):
531550
key, subkey = jax.random.split(key)
532551
A = jax.random.uniform(subkey, (n, n))
533552
matrices.append(A)
534-
print(A)
535553
return matrices
536554
```
537555

538556
```{code-cell} ipython3
539557
seed = 42
540558
key = jax.random.key(seed)
541-
matrices = gen_random_matrices(key)
542-
```
543-
544-
我们也可以在循环迭代时使用 `fold_in`
545-
546-
```{code-cell} ipython3
547-
def gen_random_matrices(key, n=2, k=3):
548-
matrices = []
549-
for i in range(k):
550-
step_key = jax.random.fold_in(key, i)
551-
A = jax.random.uniform(step_key, (n, n))
552-
matrices.append(A)
553-
print(A)
554-
return matrices
555-
```
556-
557-
```{code-cell} ipython3
558-
key = jax.random.key(seed)
559-
matrices = gen_random_matrices(key)
560-
```
561-
562-
### 为什么要显式随机状态?
563-
564-
为什么 JAX 需要这种相对冗长的随机数生成方法?
565-
566-
一个原因是为了维护纯函数。
567-
568-
让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。
569-
570-
#### NumPy 的方法
571-
572-
在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。
573-
574-
每次我们调用随机函数时,这个状态都会被更新:
575-
576-
```{code-cell} ipython3
577-
np.random.seed(42)
578-
print(np.random.randn()) # Updates state of random number generator
579-
print(np.random.randn()) # Updates state of random number generator
559+
gen_random_matrices(key)
580560
```
581561

582-
每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。
562+
这个函数是*纯函数*
583563

584-
这个函数*不是纯函数*,因为:
585-
586-
* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
587-
* 它有副作用:它修改了全局随机数生成器状态
588-
589-
#### JAX 的方法
590-
591-
如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。
592-
593-
例如:
594-
595-
```{code-cell} ipython3
596-
def random_sum_jax(key):
597-
key1, key2 = jax.random.split(key)
598-
x = jax.random.normal(key1)
599-
y = jax.random.normal(key2)
600-
return x + y
601-
```
602-
603-
使用相同的密钥,我们总是得到相同的结果:
604-
605-
```{code-cell} ipython3
606-
key = jax.random.key(42)
607-
random_sum_jax(key)
608-
```
609-
610-
```{code-cell} ipython3
611-
random_sum_jax(key)
612-
```
613-
614-
要获得新的抽取,我们需要提供一个新密钥。
615-
616-
函数 `random_sum_jax` 是纯函数,因为:
617-
618-
* 它是确定性的:相同的密钥总是产生相同的输出
564+
* 确定性的:相同的输入,相同的输出
619565
* 无副作用:没有隐藏状态被修改
620566

567+
### 好处
568+
621569
JAX 的显式性带来了显著的好处:
622570

623571
* 可复现性:通过重用密钥轻松重现结果
624-
* 并行化:每个线程可以拥有自己的密钥而不会产生冲突
625-
* 调试:没有隐藏状态使代码更容易推理
572+
* 并行化:控制各个线程上发生的事情
573+
* 调试:没有隐藏状态使代码更容易测试
626574
* JIT 兼容性:编译器可以更积极地优化纯函数
627575

628-
最后一点将在下一节中进行扩展。
629-
630576
## JIT 编译
631577

632578
JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
633579

634580
我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。
635581

636-
让我们用一个更复杂的函数尝试同样的操作:
582+
这里我们研究更复杂函数的 JIT 编译。
583+
584+
### 使用 NumPy
585+
586+
我们先用 NumPy 试试,使用:
637587

638588
```{code-cell}
639589
def f(x):
640590
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
641591
return y
642592
```
643593

644-
### 使用 NumPy
645-
646-
我们先用 NumPy 试试:
594+
用较大的 `x` 运行:
647595

648596
```{code-cell}
649597
n = 50_000_000
@@ -656,9 +604,17 @@ with qe.Timer():
656604
y = f(x)
657605
```
658606

659-
### 使用 JAX
607+
**急切**(Eager)执行模型
608+
609+
* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。
610+
611+
缺点
612+
613+
* 并行化程度最低
614+
* 较大的内存占用——产生许多中间数组
615+
* 大量内存读写
660616

661-
现在让我们用 JAX 再试一次。
617+
### 使用 JAX
662618

663619
作为第一步,我们将整个代码中的 `np` 替换为 `jnp`
664620

@@ -691,11 +647,13 @@ with qe.Timer():
691647

692648
结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
693649

694-
然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作
650+
但我们仍在使用即时执行——大量内存和读写开销
695651

696652
### 编译整个函数
697653

698-
JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。
654+
幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
655+
656+
编译器将所有数组运算融合到单个优化内核中。
699657

700658
让我们用函数 `f` 来试试这个:
701659

@@ -719,9 +677,11 @@ with qe.Timer():
719677
jax.block_until_ready(y);
720678
```
721679

722-
运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
680+
运行时间再次改善——现在是因为我们融合了所有操作
723681

724-
例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
682+
* 基于整个计算序列的积极优化
683+
* 消除对硬件加速器的多次调用
684+
* 不创建中间数组
725685

726686
顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
727687

@@ -741,11 +701,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP
741701

742702
### 编译非纯函数
743703

744-
现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
704+
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!
745705

746-
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。
747-
748-
以下是一个使用全局变量的例子:
706+
以下是一个例子:
749707

750708
```{code-cell} ipython3
751709
a = 1 # global
@@ -789,9 +747,9 @@ f(x)
789747

790748
## 使用 `vmap` 进行向量化
791749

792-
JAX 的另一个强大变换是 `jax.vmap`它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行
750+
JAX 的另一个强大变换是 `jax.vmap`它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作
793751

794-
这避免了手动编写向量化代码或使用显式循环的需要
752+
这样就无需手动编写向量化代码或使用显式循环
795753

796754
### 一个简单的示例
797755

@@ -809,7 +767,7 @@ x = jnp.array([1.0, 2.0, 5.0])
809767
mm_diff(x)
810768
```
811769

812-
现在假设我们有一个矩阵,想要对每一行计算这些统计量
770+
现在假设我们有一个矩阵,希望对每一行计算这些统计量
813771

814772
不使用 `vmap` 时,我们需要显式循环:
815773

@@ -824,18 +782,16 @@ for row in X:
824782

825783
然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。
826784

827-
使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit``grad`)组合使用
785+
使用 `vmap`,我们可以避免循环,并将计算保留在加速器上
828786

829787
```{code-cell} ipython3
830-
batch_mm_diff = jax.vmap(mm_diff)
831-
batch_mm_diff(X)
788+
batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version
789+
batch_mm_diff(X) # Apply to each row of X
832790
```
833791

834-
函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。
835-
836792
### 组合变换
837793

838-
JAX 的优势之一在于各变换可以自然地组合使用
794+
JAX 的优势之一在于各种变换可以自然地组合使用
839795

840796
例如,我们可以对向量化函数进行 JIT 编译:
841797

@@ -844,7 +800,7 @@ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
844800
fast_batch_mm_diff(X)
845801
```
846802

847-
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
803+
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
848804

849805
## 练习
850806

0 commit comments

Comments
 (0)