@@ -21,22 +21,24 @@ translation:
2121 JAX as a NumPy Replacement::Differences::Size Experiment : 大小实验
2222 JAX as a NumPy Replacement::Differences::Precision : 精度
2323 JAX as a NumPy Replacement::Differences::Immutability : 不可变性
24- JAX as a NumPy Replacement::Differences::A workaround : 变通方法
24+ JAX as a NumPy Replacement::Differences::A Workaround : 变通方法
2525 Functional Programming : 函数式编程
2626 Functional Programming::Pure functions : 纯函数
2727 Functional Programming::Examples : 示例
28- Functional Programming::Why Functional Programming? : 为什么使用函数式编程 ?
28+ Functional Programming::Why Functional Programming? : 为什么要函数式编程 ?
2929 Random numbers : 随机数
30- Random numbers::Random number generation : 随机数生成
31- Random numbers::Why explicit random state? : 为什么要显式随机状态?
32- Random numbers::Why explicit random state?::NumPy's approach : NumPy 的方法
33- Random numbers::Why explicit random state?::JAX's approach : JAX 的方法
30+ Random numbers::NumPy / MATLAB Approach : NumPy / MATLAB 方法
31+ Random numbers::JAX : JAX
32+ Random numbers::Benefits : 好处
3433 JIT Compilation : JIT 编译
3534 JIT Compilation::With NumPy : 使用 NumPy
3635 JIT Compilation::With JAX : 使用 JAX
3736 JIT Compilation::Compiling the Whole Function : 编译整个函数
3837 JIT Compilation::How JIT compilation works : JIT 编译的工作原理
3938 JIT Compilation::Compiling non-pure functions : 编译非纯函数
39+ Vectorization with `vmap` : 使用 `vmap` 进行向量化
40+ Vectorization with `vmap`::A simple example : 一个简单的示例
41+ Vectorization with `vmap`::Combining transformations : 组合变换
4042 Exercises : 练习
4143---
4244
@@ -404,13 +406,29 @@ JAX 使用函数式编程风格,以便用户构建的函数能够直接映射
404406
405407JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。
406408
407- 起初,您可能会觉得语法相当冗长。
409+ ### NumPy / MATLAB 方法
408410
409- 但为了维护我们刚刚讨论的函数式编程风格,这种语法和语义是必要的 。
411+ 在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作 。
410412
411- 此外,对随机状态的完全控制对于并行编程至关重要,例如当我们想要沿多个线程运行独立实验时。
413+ ``` {code-cell} ipython3
414+ np.random.seed(42)
415+ print(np.random.randn(2))
416+ ```
417+
418+ 每次我们调用随机函数时,隐藏状态都会被更新:
419+
420+ ``` {code-cell} ipython3
421+ print(np.random.randn(2))
422+ ```
423+
424+ 这个函数* 不是纯函数* ,因为:
412425
413- ### 随机数生成
426+ * 它是非确定性的:相同的输入,不同的输出
427+ * 它有副作用:它修改了全局随机数生成器状态
428+
429+ 在并行化下很危险——必须仔细控制每个线程中发生的事情!
430+
431+ ### JAX
414432
415433在 JAX 中,随机数生成器的状态被显式控制。
416434
@@ -531,119 +549,48 @@ def gen_random_matrices(key, n=2, k=3):
531549 key, subkey = jax.random.split(key)
532550 A = jax.random.uniform(subkey, (n, n))
533551 matrices.append(A)
534- print(A)
535552 return matrices
536553```
537554
538555``` {code-cell} ipython3
539556seed = 42
540557key = jax.random.key(seed)
541- matrices = gen_random_matrices(key)
542- ```
543-
544- 我们也可以在循环迭代时使用 ` fold_in ` :
545-
546- ``` {code-cell} ipython3
547- def gen_random_matrices(key, n=2, k=3):
548- matrices = []
549- for i in range(k):
550- step_key = jax.random.fold_in(key, i)
551- A = jax.random.uniform(step_key, (n, n))
552- matrices.append(A)
553- print(A)
554- return matrices
555- ```
556-
557- ``` {code-cell} ipython3
558- key = jax.random.key(seed)
559- matrices = gen_random_matrices(key)
560- ```
561-
562- ### 为什么要显式随机状态?
563-
564- 为什么 JAX 需要这种相对冗长的随机数生成方法?
565-
566- 一个原因是为了维护纯函数。
567-
568- 让我们通过比较 NumPy 和 JAX 来看看随机数生成与纯函数的关系。
569-
570- #### NumPy 的方法
571-
572- 在 NumPy 的旧版随机数生成 API(模仿 MATLAB)中,生成通过维护隐藏的全局状态来工作。
573-
574- 每次我们调用随机函数时,这个状态都会被更新:
575-
576- ``` {code-cell} ipython3
577- np.random.seed(42)
578- print(np.random.randn()) # Updates state of random number generator
579- print(np.random.randn()) # Updates state of random number generator
558+ gen_random_matrices(key)
580559```
581560
582- 每次调用都返回不同的值,即使我们用相同的输入(没有参数)调用相同的函数。
561+ 这个函数是 * 纯函数 * :
583562
584- 这个函数* 不是纯函数* ,因为:
585-
586- * 它是非确定性的:相同的输入(在这种情况下,没有输入)产生不同的输出
587- * 它有副作用:它修改了全局随机数生成器状态
588-
589- #### JAX 的方法
590-
591- 如上所示,JAX 采用了不同的方法,通过密钥使随机性显式化。
592-
593- 例如:
594-
595- ``` {code-cell} ipython3
596- def random_sum_jax(key):
597- key1, key2 = jax.random.split(key)
598- x = jax.random.normal(key1)
599- y = jax.random.normal(key2)
600- return x + y
601- ```
602-
603- 使用相同的密钥,我们总是得到相同的结果:
604-
605- ``` {code-cell} ipython3
606- key = jax.random.key(42)
607- random_sum_jax(key)
608- ```
609-
610- ``` {code-cell} ipython3
611- random_sum_jax(key)
612- ```
613-
614- 要获得新的抽取,我们需要提供一个新密钥。
615-
616- 函数 ` random_sum_jax ` 是纯函数,因为:
617-
618- * 它是确定性的:相同的密钥总是产生相同的输出
563+ * 确定性的:相同的输入,相同的输出
619564* 无副作用:没有隐藏状态被修改
620565
566+ ### 好处
567+
621568JAX 的显式性带来了显著的好处:
622569
623570* 可复现性:通过重用密钥轻松重现结果
624- * 并行化:每个线程可以拥有自己的密钥而不会产生冲突
625- * 调试:没有隐藏状态使代码更容易推理
571+ * 并行化:控制各个线程上发生的事情
572+ * 调试:没有隐藏状态使代码更容易测试
626573* JIT 兼容性:编译器可以更积极地优化纯函数
627574
628- 最后一点将在下一节中进行扩展。
629-
630575## JIT 编译
631576
632577JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。
633578
634579我们在 {ref}` 上文 <jax_speed> ` 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 ` cos ` 函数。
635580
636- 让我们用一个更复杂的函数尝试同样的操作:
581+ 这里我们研究更复杂函数的 JIT 编译。
582+
583+ ### 使用 NumPy
584+
585+ 我们先用 NumPy 试试,使用:
637586
638587``` {code-cell}
639588def f(x):
640589 y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
641590 return y
642591```
643592
644- ### 使用 NumPy
645-
646- 我们先用 NumPy 试试:
593+ 用较大的 ` x ` 运行:
647594
648595``` {code-cell}
649596n = 50_000_000
@@ -656,9 +603,17 @@ with qe.Timer():
656603 y = f(x)
657604```
658605
659- ### 使用 JAX
606+ ** 即时** 执行模型
607+
608+ * 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。
609+
610+ 缺点
611+
612+ * 并行化程度最低
613+ * 较大的内存占用——产生许多中间数组
614+ * 大量内存读写
660615
661- 现在让我们用 JAX 再试一次。
616+ ### 使用 JAX
662617
663618作为第一步,我们将整个代码中的 ` np ` 替换为 ` jnp ` :
664619
@@ -691,11 +646,13 @@ with qe.Timer():
691646
692647结果与 ` cos ` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。
693648
694- 然而,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作 。
649+ 但我们仍在使用即时执行——大量内存和读写开销 。
695650
696651### 编译整个函数
697652
698- JAX 即时(JIT)编译器可以通过将数组运算融合到单个优化内核中来加速函数内部的执行。
653+ 幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。
654+
655+ 编译器将所有数组运算融合到单个优化内核中。
699656
700657让我们用函数 ` f ` 来试试这个:
701658
@@ -719,9 +676,11 @@ with qe.Timer():
719676 jax.block_until_ready(y);
720677```
721678
722- 运行时间再次改善——现在是因为我们融合了所有操作,使编译器能够更积极地进行优化。
679+ 运行时间再次改善——现在是因为我们融合了所有操作:
723680
724- 例如,编译器可以消除对硬件加速器的多次调用以及许多中间数组的创建。
681+ * 基于整个计算序列的积极优化
682+ * 消除对硬件加速器的多次调用
683+ * 不创建中间数组
725684
726685顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:
727686
@@ -741,11 +700,9 @@ XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TP
741700
742701### 编译非纯函数
743702
744- 现在我们已经看到了 JIT 编译的强大之处,理解它与纯函数的关系非常重要。
703+ 虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!
745704
746- 虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测。
747-
748- 以下是一个使用全局变量的例子:
705+ 以下是一个例子:
749706
750707``` {code-cell} ipython3
751708a = 1 # global
789746
790747## 使用 ` vmap ` 进行向量化
791748
792- JAX 的另一个强大变换是 ` jax.vmap ` ,它能自动将一个针对单个输入编写的函数向量化,使其可以在批量数据上运行 。
749+ JAX 的另一个强大变换是 ` jax.vmap ` ,它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作 。
793750
794- 这避免了手动编写向量化代码或使用显式循环的需要 。
751+ 这样就无需手动编写向量化代码或使用显式循环 。
795752
796753### 一个简单的示例
797754
@@ -809,7 +766,7 @@ x = jnp.array([1.0, 2.0, 5.0])
809766mm_diff(x)
810767```
811768
812- 现在假设我们有一个矩阵,想要对每一行计算这些统计量 。
769+ 现在假设我们有一个矩阵,希望对每一行计算这些统计量 。
813770
814771不使用 ` vmap ` 时,我们需要显式循环:
815772
@@ -824,18 +781,16 @@ for row in X:
824781
825782然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。
826783
827- 使用 ` vmap ` 可以将计算保留在加速器上,并与其他 JAX 变换(如 ` jit ` 和 ` grad ` )组合使用 :
784+ 使用 ` vmap ` ,我们可以避免循环,并将计算保留在加速器上 :
828785
829786``` {code-cell} ipython3
830- batch_mm_diff = jax.vmap(mm_diff)
831- batch_mm_diff(X)
787+ batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version
788+ batch_mm_diff(X) # Apply to each row of X
832789```
833790
834- 函数 ` mm_diff ` 是针对单个数组编写的,而 ` vmap ` 自动将其提升为按行作用于矩阵的函数——无需循环,无需重新塑形。
835-
836791### 组合变换
837792
838- JAX 的优势之一在于各变换可以自然地组合使用 。
793+ JAX 的优势之一在于各种变换可以自然地组合使用 。
839794
840795例如,我们可以对向量化函数进行 JIT 编译:
841796
@@ -844,7 +799,7 @@ fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
844799fast_batch_mm_diff(X)
845800```
846801
847- ` jit ` 、` vmap ` 以及(我们接下来将看到的)` grad ` 的这种组合方式是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
802+ ` jit ` 、` vmap ` 以及(我们接下来将看到的)` grad ` 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。
848803
849804## 练习
850805
0 commit comments