Skip to content

Commit 6e17ee1

Browse files
committed
Update translation: lectures/jax_intro.md
1 parent 6f36a65 commit 6e17ee1

1 file changed

Lines changed: 69 additions & 114 deletions

File tree

lectures/jax_intro.md

Lines changed: 69 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,24 @@ translation:
2121
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2222
JAX as a NumPy Replacement::Differences::Precision: 精度
2323
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
24-
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
24+
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
2525
Functional Programming: 函数式编程
2626
Functional Programming::Pure functions: 纯函数
2727
Functional Programming::Examples: 示例
28-
Functional Programming::Why Functional Programming?: 为什么使用函数式编程
28+
Functional Programming::Why Functional Programming?: 为什么要函数式编程
2929
Random numbers: 随机数
30-
Random numbers::Random number generation: 随机数生成
31-
Random numbers::Why explicit random state?: 为什么要显式随机状态?
32-
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
33-
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
30+
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
31+
Random numbers::JAX: JAX
32+
Random numbers::Benefits: 好处
3433
JIT Compilation: JIT 编译
3534
JIT Compilation::With NumPy: 使用 NumPy
3635
JIT Compilation::With JAX: 使用 JAX
3736
JIT Compilation::Compiling the Whole Function: 编译整个函数
3837
JIT Compilation::How JIT compilation works: JIT 编译的工作原理
3938
JIT Compilation::Compiling non-pure functions: 编译非纯函数
39+
Vectorization with `vmap`: 使用 `vmap` 进行向量化
40+
Vectorization with `vmap`::A simple example: 一个简单的示例
41+
Vectorization with `vmap`::Combining transformations: 组合变换
4042
Exercises: 练习
4143
---
4244

@@ -404,13 +406,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射
404406

405407
JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。
406408

407-
起初,您可能会觉得语法相当冗长。
409+
### NumPy / MATLAB 方法
408410

409-
但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的
411+
在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作
410412

411-
此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
413+
```{code-cell} ipython3
414+
np.random.seed(42)
415+
print(np.random.randn(2))
416+
```
417+
418+
每次我们调用随机函数时,隐藏状态都会被更新:
419+
420+
```{code-cell} ipython3
421+
print(np.random.randn(2))
422+
```
423+
424+
这个函数*不是纯函数*,因为:
412425

413-
### 随机数生成
426+
* 它是非确定性的:相同的输入,不同的输出
427+
* 它有副作用:它修改了全局随机数生成器状态
428+
429+
在并行化下很危险——必须仔细控制每个线程中发生的事情!
430+
431+
### JAX
414432

415433
在 JAX 中,随机数生成器的状态被显式控制。
416434

@@ -531,119 +549,48 @@ def gen_random_matrices(key, n=2, k=3):
531549
key, subkey = jax.random.split(key)
532550
A = jax.random.uniform(subkey, (n, n))
533551
matrices.append(A)
534-
print(A)
535552
return matrices
536553
```
537554

538555
```{code-cell} ipython3
539556
seed = 42
540557
key = jax.random.key(seed)
541-
matrices = gen_random_matrices(key)
542-
```
543-
544-
我们也可以在循环迭代时使用 `fold_in`
545-
546-
```{code-cell} ipython3
547-
def gen_random_matrices(key, n=2, k=3):
548-
matrices = []
549-
for i in range(k):
550-
step_key = jax.random.fold_in(key, i)
551-
A = jax.random.uniform(step_key, (n, n))
552-
matrices.append(A)
553-
print(A)
554-
return matrices
555-
```
556-
557-
```{code-cell} ipython3
558-
key = jax.random.key(seed)
559-
matrices = gen_random_matrices(key)
560-
```
561-
562-
### 为什么要显式随机状态?
563-
564-
为什么 JAX 需要这种相对冗长的随机数生成方法?
565-
566-
一个原因是为了维护纯函数。
567-
568-
让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。
569-
570-
#### NumPy 的方法
571-
572-
在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。
573-
574-
每次我们调用随机函数时,这个状态都会被更新:
575-
576-
```{code-cell} ipython3
577-
np.random.seed(42)
578-
print(np.random.randn()) # Updates state of random number generator
579-
print(np.random.randn()) # Updates state of random number generator
558+
gen_random_matrices(key)
580559
```
581560

582-
每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。
561+
这个函数是*纯函数*
583562

584-
这个函数*不是纯函数*,因为:
585-
586-
* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
587-
* 它有副作用:它修改了全局随机数生成器状态
588-
589-
#### JAX 的方法
590-
591-
如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。
592-
593-
例如:
594-
595-
```{code-cell} ipython3
596-
def random_sum_jax(key):
597-
key1, key2 = jax.random.split(key)
598-
x = jax.random.normal(key1)
599-
y = jax.random.normal(key2)
600-
return x + y
601-
```
602-
603-
使用相同的密钥,我们总是得到相同的结果:
604-
605-
```{code-cell} ipython3
606-
key = jax.random.key(42)
607-
random_sum_jax(key)
608-
```
609-
610-
```{code-cell} ipython3
611-
random_sum_jax(key)
612-
```
613-
614-
要获得新的抽取,我们需要提供一个新密钥。
615-
616-
函数 `random_sum_jax` 是纯函数,因为:
617-
618-
* 它是确定性的:相同的密钥总是产生相同的输出
563+
* 确定性的:相同的输入,相同的输出
619564
* 无副作用:没有隐藏状态被修改
620565

566+
### 好处
567+
621568
JAX 的显式性带来了显著的好处:
622569

623570
* 可复现性:通过重用密钥轻松重现结果
624-
* 并行化:每个线程可以拥有自己的密钥而不会产生冲突
625-
* 调试:没有隐藏状态使代码更容易推理
571+
* 并行化:控制各个线程上发生的事情
572+
* 调试:没有隐藏状态使代码更容易测试
626573
* JIT 兼容性:编译器可以更积极地优化纯函数
627574

628-
最后一点将在下一节中进行扩展。
629-
630575
## JIT 编译
631576

632577
JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
633578

634579
我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。
635580

636-
让我们用一个更复杂的函数尝试同样的操作:
581+
这里我们研究更复杂函数的 JIT 编译。
582+
583+
### 使用 NumPy
584+
585+
我们先用 NumPy 试试,使用:
637586

638587
```{code-cell}
639588
def f(x):
640589
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
641590
return y
642591
```
643592

644-
### 使用 NumPy
645-
646-
我们先用 NumPy 试试:
593+
用较大的 `x` 运行:
647594

648595
```{code-cell}
649596
n = 50_000_000
@@ -656,9 +603,17 @@ with qe.Timer():
656603
y = f(x)
657604
```
658605

659-
### 使用 JAX
606+
**即时**执行模型
607+
608+
* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。
609+
610+
缺点
611+
612+
* 并行化程度最低
613+
* 较大的内存占用——产生许多中间数组
614+
* 大量内存读写
660615

661-
现在让我们用 JAX 再试一次。
616+
### 使用 JAX
662617

663618
作为第一步,我们将整个代码中的 `np` 替换为 `jnp`
664619

@@ -691,11 +646,13 @@ with qe.Timer():
691646

692647
结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
693648

694-
然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作
649+
但我们仍在使用即时执行——大量内存和读写开销
695650

696651
### 编译整个函数
697652

698-
JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。
653+
幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
654+
655+
编译器将所有数组运算融合到单个优化内核中。
699656

700657
让我们用函数 `f` 来试试这个:
701658

@@ -719,9 +676,11 @@ with qe.Timer():
719676
jax.block_until_ready(y);
720677
```
721678

722-
运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
679+
运行时间再次改善——现在是因为我们融合了所有操作
723680

724-
例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
681+
* 基于整个计算序列的积极优化
682+
* 消除对硬件加速器的多次调用
683+
* 不创建中间数组
725684

726685
顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
727686

@@ -741,11 +700,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP
741700

742701
### 编译非纯函数
743702

744-
现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
703+
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!
745704

746-
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。
747-
748-
以下是一个使用全局变量的例子:
705+
以下是一个例子:
749706

750707
```{code-cell} ipython3
751708
a = 1 # global
@@ -789,9 +746,9 @@ f(x)
789746

790747
## 使用 `vmap` 进行向量化
791748

792-
JAX 的另一个强大变换是 `jax.vmap`它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行
749+
JAX 的另一个强大变换是 `jax.vmap`它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作
793750

794-
这避免了手动编写向量化代码或使用显式循环的需要
751+
这样就无需手动编写向量化代码或使用显式循环
795752

796753
### 一个简单的示例
797754

@@ -809,7 +766,7 @@ x = jnp.array([1.0, 2.0, 5.0])
809766
mm_diff(x)
810767
```
811768

812-
现在假设我们有一个矩阵,想要对每一行计算这些统计量
769+
现在假设我们有一个矩阵,希望对每一行计算这些统计量
813770

814771
不使用 `vmap` 时,我们需要显式循环:
815772

@@ -824,18 +781,16 @@ for row in X:
824781

825782
然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。
826783

827-
使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit``grad`)组合使用
784+
使用 `vmap`,我们可以避免循环,并将计算保留在加速器上
828785

829786
```{code-cell} ipython3
830-
batch_mm_diff = jax.vmap(mm_diff)
831-
batch_mm_diff(X)
787+
batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version
788+
batch_mm_diff(X) # Apply to each row of X
832789
```
833790

834-
函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。
835-
836791
### 组合变换
837792

838-
JAX 的优势之一在于各变换可以自然地组合使用
793+
JAX 的优势之一在于各种变换可以自然地组合使用
839794

840795
例如,我们可以对向量化函数进行 JIT 编译:
841796

@@ -844,7 +799,7 @@ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
844799
fast_batch_mm_diff(X)
845800
```
846801

847-
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
802+
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
848803

849804
## 练习
850805

0 commit comments

Comments
 (0)