Skip to content

Commit 2989842

Browse files
committed
Update translation: lectures/jax_intro.md
1 parent 6f36a65 commit 2989842

1 file changed

Lines changed: 87 additions & 173 deletions

File tree

lectures/jax_intro.md

Lines changed: 87 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,15 @@ translation:
2121
JAX as a NumPy Replacement::Differences::Size Experiment: 大小实验
2222
JAX as a NumPy Replacement::Differences::Precision: 精度
2323
JAX as a NumPy Replacement::Differences::Immutability: 不可变性
24-
JAX as a NumPy Replacement::Differences::A workaround: 变通方法
24+
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
2525
Functional Programming: 函数式编程
2626
Functional Programming::Pure functions: 纯函数
27-
Functional Programming::Examples: 示例
28-
Functional Programming::Why Functional Programming?: 为什么使用函数式编程
27+
Functional Programming::Examples -- Pure and Impure: 示例——纯函数与非纯函数
28+
Functional Programming::Why Functional Programming?: 为什么要函数式编程
2929
Random numbers: 随机数
30-
Random numbers::Random number generation: 随机数生成
31-
Random numbers::Why explicit random state?: 为什么要显式随机状态?
32-
Random numbers::Why explicit random state?::NumPy's approach: NumPy 的方法
33-
Random numbers::Why explicit random state?::JAX's approach: JAX 的方法
30+
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
31+
Random numbers::JAX: JAX
32+
Random numbers::Benefits: 优势
3433
JIT Compilation: JIT 编译
3534
JIT Compilation::With NumPy: 使用 NumPy
3635
JIT Compilation::With JAX: 使用 JAX
@@ -343,19 +342,20 @@ a
343342
* 不会改变全局状态
344343
* 不会修改传递给函数的数据(不可变数据)
345344

346-
### 示例
345+
### 示例——纯函数与非纯函数
347346

348347
以下是一个*非纯*函数的示例:
349348

350349
```{code-cell} ipython3
351350
tax_rate = 0.1
352-
prices = [10.0, 20.0]
353351
354352
def add_tax(prices):
355353
for i, price in enumerate(prices):
356354
prices[i] = price * (1 + tax_rate)
357-
print('Post-tax prices: ', prices)
358-
return prices
355+
356+
prices = [10.0, 20.0]
357+
add_tax(prices)
358+
prices
359359
```
360360

361361
这个函数不是纯函数,因为:
@@ -366,15 +366,21 @@ def add_tax(prices):
366366
以下是一个**版本:
367367

368368
```{code-cell} ipython3
369-
tax_rate = 0.1
370-
prices = (10.0, 20.0)
371369
372370
def add_tax_pure(prices, tax_rate):
373371
new_prices = [price * (1 + tax_rate) for price in prices]
374372
return new_prices
373+
374+
tax_rate = 0.1
375+
prices = (10.0, 20.0)
376+
after_tax_prices = add_tax_pure(prices, tax_rate)
377+
after_tax_prices
375378
```
376379

377-
这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。
380+
这是纯函数,因为:
381+
382+
* 所有依赖关系通过函数参数显式表达
383+
* 不修改任何外部状态
378384

379385
### 为什么要函数式编程?
380386

@@ -404,13 +410,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射
404410

405411
JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。
406412

407-
起初,您可能会觉得语法相当冗长。
413+
### NumPy / MATLAB 方法
408414

409-
但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的
415+
在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作
410416

411-
此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
417+
```{code-cell} ipython3
418+
np.random.seed(42)
419+
print(np.random.randn(2))
420+
```
412421

413-
### 随机数生成
422+
每次我们调用随机函数时,隐藏状态都会被更新:
423+
424+
```{code-cell} ipython3
425+
print(np.random.randn(2))
426+
```
427+
428+
这个函数*不是纯函数*,因为:
429+
430+
* 它是非确定性的:相同的输入,不同的输出
431+
* 它有副作用:它修改了全局随机数生成器状态
432+
433+
这在并行化下是危险的——必须仔细控制每个线程中发生的事情。
434+
435+
### JAX
414436

415437
在 JAX 中,随机数生成器的状态被显式控制。
416438

@@ -525,125 +547,58 @@ plt.show()
525547
下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。
526548

527549
```{code-cell} ipython3
528-
def gen_random_matrices(key, n=2, k=3):
550+
def gen_random_matrices(
551+
key, # JAX key for random numbers
552+
n=2, # Matrices will be n x n
553+
k=3 # Number of matrices to generate
554+
):
529555
matrices = []
530556
for _ in range(k):
531557
key, subkey = jax.random.split(key)
532558
A = jax.random.uniform(subkey, (n, n))
533559
matrices.append(A)
534-
print(A)
535560
return matrices
536561
```
537562

538563
```{code-cell} ipython3
539564
seed = 42
540565
key = jax.random.key(seed)
541-
matrices = gen_random_matrices(key)
542-
```
543-
544-
我们也可以在循环迭代时使用 `fold_in`
545-
546-
```{code-cell} ipython3
547-
def gen_random_matrices(key, n=2, k=3):
548-
matrices = []
549-
for i in range(k):
550-
step_key = jax.random.fold_in(key, i)
551-
A = jax.random.uniform(step_key, (n, n))
552-
matrices.append(A)
553-
print(A)
554-
return matrices
566+
gen_random_matrices(key)
555567
```
556568

557-
```{code-cell} ipython3
558-
key = jax.random.key(seed)
559-
matrices = gen_random_matrices(key)
560-
```
561-
562-
### 为什么要显式随机状态?
563-
564-
为什么 JAX 需要这种相对冗长的随机数生成方法?
565-
566-
一个原因是为了维护纯函数。
567-
568-
让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。
569-
570-
#### NumPy 的方法
571-
572-
在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。
573-
574-
每次我们调用随机函数时,这个状态都会被更新:
575-
576-
```{code-cell} ipython3
577-
np.random.seed(42)
578-
print(np.random.randn()) # Updates state of random number generator
579-
print(np.random.randn()) # Updates state of random number generator
580-
```
581-
582-
每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。
583-
584-
这个函数*不是纯函数*,因为:
585-
586-
* 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
587-
* 它有副作用:它修改了全局随机数生成器状态
588-
589-
#### JAX 的方法
590-
591-
如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。
592-
593-
例如:
594-
595-
```{code-cell} ipython3
596-
def random_sum_jax(key):
597-
key1, key2 = jax.random.split(key)
598-
x = jax.random.normal(key1)
599-
y = jax.random.normal(key2)
600-
return x + y
601-
```
602-
603-
使用相同的密钥,我们总是得到相同的结果:
604-
605-
```{code-cell} ipython3
606-
key = jax.random.key(42)
607-
random_sum_jax(key)
608-
```
569+
这个函数是*纯函数*
609570

610-
```{code-cell} ipython3
611-
random_sum_jax(key)
612-
```
613-
614-
要获得新的抽取,我们需要提供一个新密钥。
615-
616-
函数 `random_sum_jax` 是纯函数,因为:
617-
618-
* 它是确定性的:相同的密钥总是产生相同的输出
571+
* 确定性:相同的输入,相同的输出
619572
* 无副作用:没有隐藏状态被修改
620573

621-
JAX 的显式性带来了显著的好处:
574+
### 优势
575+
576+
如上所述,这种显式性是有价值的:
622577

623578
* 可复现性:通过重用密钥轻松重现结果
624-
* 并行化:每个线程可以拥有自己的密钥而不会产生冲突
625-
* 调试:没有隐藏状态使代码更容易推理
579+
* 并行化:控制每个独立线程中发生的事情
580+
* 调试:没有隐藏状态使代码更容易测试
626581
* JIT 兼容性:编译器可以更积极地优化纯函数
627582

628-
最后一点将在下一节中进行扩展。
629-
630583
## JIT 编译
631584

632585
JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
633586

634587
我们在 {ref}`上文 <jax_speed>` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 `cos` 函数。
635588

636-
让我们用一个更复杂的函数尝试同样的操作:
589+
这里我们研究针对更复杂函数的 JIT 编译。
590+
591+
### 使用 NumPy
592+
593+
我们先用 NumPy 试试:
637594

638595
```{code-cell}
639596
def f(x):
640597
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
641598
return y
642599
```
643600

644-
### 使用 NumPy
645-
646-
我们先用 NumPy 试试:
601+
让我们用较大的 `x` 运行:
647602

648603
```{code-cell}
649604
n = 50_000_000
@@ -656,9 +611,17 @@ with qe.Timer():
656611
y = f(x)
657612
```
658613

659-
### 使用 JAX
614+
**即时执行**模型
615+
616+
* 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。
660617

661-
现在让我们用 JAX 再试一次。
618+
缺点
619+
620+
* 并行化程度极低
621+
* 内存占用大——产生许多中间数组
622+
* 大量内存读写
623+
624+
### 使用 JAX
662625

663626
作为第一步,我们将整个代码中的 `np` 替换为 `jnp`
664627

@@ -691,11 +654,20 @@ with qe.Timer():
691654

692655
结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
693656

694-
然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
657+
这是因为单个数组操作在 GPU 上并行化了。
658+
659+
但我们仍然在使用即时执行模式
660+
661+
* 由于中间数组导致大量内存占用
662+
* 大量内存读写
663+
664+
此外,在 GPU 上还启动了许多独立的内核。
695665

696666
### 编译整个函数
697667

698-
JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。
668+
幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
669+
670+
编译器将所有数组操作融合到单个优化内核中。
699671

700672
让我们用函数 `f` 来试试这个:
701673

@@ -719,9 +691,12 @@ with qe.Timer():
719691
jax.block_until_ready(y);
720692
```
721693

722-
运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
694+
运行时间再次改善——现在是因为我们融合了所有操作
723695

724-
例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
696+
* 基于整个计算序列的激进优化
697+
* 消除对硬件加速器的多次调用
698+
699+
内存占用也大幅降低——不再创建中间数组。
725700

726701
顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
727702

@@ -741,11 +716,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP
741716

742717
### 编译非纯函数
743718

744-
现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
745-
746-
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。
719+
虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!
747720

748-
以下是一个使用全局变量的例子
721+
以下是一个示例
749722

750723
```{code-cell} ipython3
751724
a = 1 # global
@@ -787,65 +760,6 @@ f(x)
787760

788761
这个故事的寓意:使用 JAX 时请编写纯函数!
789762

790-
## 使用 `vmap` 进行向量化
791-
792-
JAX 的另一个强大变换是 `jax.vmap`,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行。
793-
794-
这避免了手动编写向量化代码或使用显式循环的需要。
795-
796-
### 一个简单的示例
797-
798-
假设我们有一个函数,用于计算一组数字的均值与中位数之差。
799-
800-
```{code-cell} ipython3
801-
def mm_diff(x):
802-
return jnp.mean(x) - jnp.median(x)
803-
```
804-
805-
我们可以将其应用于单个向量:
806-
807-
```{code-cell} ipython3
808-
x = jnp.array([1.0, 2.0, 5.0])
809-
mm_diff(x)
810-
```
811-
812-
现在假设我们有一个矩阵,想要对每一行计算这些统计量。
813-
814-
不使用 `vmap` 时,我们需要显式循环:
815-
816-
```{code-cell} ipython3
817-
X = jnp.array([[1.0, 2.0, 5.0],
818-
[4.0, 5.0, 6.0],
819-
[1.0, 8.0, 9.0]])
820-
821-
for row in X:
822-
print(mm_diff(row))
823-
```
824-
825-
然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。
826-
827-
使用 `vmap` 可以将计算保留在加速器上,并与其他 JAX 变换(如 `jit``grad`)组合使用:
828-
829-
```{code-cell} ipython3
830-
batch_mm_diff = jax.vmap(mm_diff)
831-
batch_mm_diff(X)
832-
```
833-
834-
函数 `mm_diff` 是针对单个数组编写的,而 `vmap` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。
835-
836-
### 组合变换
837-
838-
JAX 的优势之一在于各变换可以自然地组合使用。
839-
840-
例如,我们可以对向量化函数进行 JIT 编译:
841-
842-
```{code-cell} ipython3
843-
fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
844-
fast_batch_mm_diff(X)
845-
```
846-
847-
`jit``vmap` 以及(我们接下来将看到的)`grad` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
848-
849763
## 练习
850764

851765

0 commit comments

Comments
 (0)