Skip to content

Commit 3ea9f1a

Browse files
committed
misc
1 parent e02d1d3 commit 3ea9f1a

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ diagonal.
153153

154154
To trick NumPy into calculating `f(x,y)` on every `x,y` pair, we need to use `np.meshgrid`.
155155

156-
Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y`
157-
156+
Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y`
158157
such that `f(x, y)` generates all evaluations on the product grid.
159158

160159

@@ -180,7 +179,7 @@ print(f"NumPy result: {z_max_numpy:.6f}")
180179

181180
### Memory Issues
182181

183-
So we have the right solution reasonable time --- but memory usage is huge.
182+
So we have the right solution in reasonable time --- but memory usage is huge.
184183

185184
While the flat arrays are low-memory
186185

@@ -337,7 +336,7 @@ The compilation overhead is a one-time cost that pays off when the function is c
337336

338337
### JAX plus vmap
339338

340-
Because we use'd `jax.jit` above, we avoided creating many intermediate arrays.
339+
Because we used `jax.jit` above, we avoided creating many intermediate arrays.
341340

342341
But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`.
343342

0 commit comments

Comments
 (0)