
Commit e02d1d3

committed
misc
1 parent 450bafe commit e02d1d3

2 files changed

Lines changed: 101 additions & 74 deletions

File tree

lectures/jax_intro.md

Lines changed: 33 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -351,19 +351,20 @@ In particular, pure functions will always return the same result if invoked with
351351

352352

353353

354-
### Examples
354+
### Examples -- Pure and Impure
355355

356-
Here's an example of a *non-pure* function
356+
Here's an example of an *impure* function
357357

358358
```{code-cell} ipython3
359359
tax_rate = 0.1
360-
prices = [10.0, 20.0]
361360
362361
def add_tax(prices):
363362
for i, price in enumerate(prices):
364363
prices[i] = price * (1 + tax_rate)
365-
print('Post-tax prices: ', prices)
366-
return prices
364+
365+
prices = [10.0, 20.0]
366+
add_tax(prices)
367+
prices
367368
```
368369

369370
This function fails to be pure because
@@ -375,15 +376,22 @@ This function fails to be pure because
375376
Here's a *pure* version
376377

377378
```{code-cell} ipython3
378-
tax_rate = 0.1
379-
prices = (10.0, 20.0)
380379
381380
def add_tax_pure(prices, tax_rate):
382381
new_prices = [price * (1 + tax_rate) for price in prices]
383382
return new_prices
383+
384+
tax_rate = 0.1
385+
prices = (10.0, 20.0)
386+
after_tax_prices = add_tax_pure(prices, tax_rate)
387+
after_tax_prices
384388
```
385389

386-
This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state.
390+
This is pure because
391+
392+
* all dependencies are passed explicitly through function arguments
393+
* it doesn't modify any external state
394+
387395

388396
### Why Functional Programming?
389397

@@ -438,8 +446,8 @@ This function is *not pure* because:
438446
* It's non-deterministic: same inputs, different outputs
439447
* It has side effects: it modifies the global random number generator state
440448

441-
Dangerous under parallelization --- must carefully control what happens in each
442-
thread!
449+
This is dangerous under parallelization --- we must carefully control what happens in each
450+
thread.
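For illustration, an impure random function of the kind described might look like this (a hypothetical sketch, not the lecture's actual code):

```python
import numpy as np

def draw(mean, std):
    # Impure: the same inputs can return different outputs, and every
    # call mutates NumPy's hidden global RNG state as a side effect
    return mean + std * np.random.randn()

a = draw(0.0, 1.0)
b = draw(0.0, 1.0)  # almost surely different from a
```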
443451

444452

445453
### JAX
@@ -560,7 +568,11 @@ sense when we get to parallel programming.
560568
The function below produces `k` (quasi-) independent random `n x n` matrices using `split`.
561569

562570
```{code-cell} ipython3
563-
def gen_random_matrices(key, n=2, k=3):
571+
def gen_random_matrices(
572+
key, # JAX key for random numbers
573+
n=2, # Matrices will be n x n
574+
k=3 # Number of matrices to generate
575+
):
564576
matrices = []
565577
for _ in range(k):
566578
key, subkey = jax.random.split(key)
@@ -583,7 +595,7 @@ This function is *pure*
583595

584596
### Benefits
585597

586-
The explicitness of JAX brings significant benefits:
598+
As mentioned above, this explicitness is valuable:
587599

588600
* Reproducibility: Easy to reproduce results by reusing keys
589601
* Parallelization: Control what happens on separate threads
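A minimal sketch of these points, using the standard `jax.random` API:

```python
import jax

key = jax.random.PRNGKey(0)

# Reproducibility: reusing a key reproduces a draw exactly
x1 = jax.random.normal(key)
x2 = jax.random.normal(key)

# Independence: splitting produces fresh keys for new draws
key, subkey = jax.random.split(key)
x3 = jax.random.normal(subkey)
```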
@@ -672,8 +684,14 @@ with qe.Timer():
672684
The outcome is similar to the `cos` example --- JAX is faster, especially on the
673685
second run after JIT compilation.
674686

675-
But we are still using eager execution --- lots of memory and read/write
687+
This is because the individual array operations are parallelized on the GPU.
676688

689+
But we are still using eager execution
690+
691+
* lots of memory due to intermediate arrays
692+
* lots of memory read/writes
693+
694+
Also, many separate kernels are launched on the GPU.
677695

678696
### Compiling the Whole Function
679697

@@ -708,7 +726,8 @@ The runtime has improved again --- now because we fused all the operations
708726

709727
* Aggressive optimization based on entire computational sequence
710728
* Eliminates multiple calls to the hardware accelerator
711-
* No creation of intermediate arrays
729+
730+
The memory footprint is also much lower --- no creation of intermediate arrays
712731

713732
Incidentally, a more common syntax when targeting a function for the JIT
714733
compiler is

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 68 additions & 60 deletions
Original file line number · Diff line number · Diff line change
@@ -135,18 +135,37 @@ for x in grid:
135135

136136
Let's switch to NumPy and use a larger grid
137137

138-
Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such
139-
that `f(x, y)` generates all evaluations on the product grid.
138+
```{code-cell} ipython3
139+
grid = np.linspace(-3, 3, 3_000) # Large grid
140+
```
141+
142+
As a first pass of vectorization we might try something like this
143+
144+
```{code-cell} ipython3
145+
# Large grid
146+
z = np.max(f(grid, grid)) # This is wrong!
147+
```
148+
149+
The problem here is that `f(grid, grid)` doesn't replicate the nested loop above.
150+
151+
In terms of the figure above, it only computes the values of `f` along the
152+
diagonal.
153+
154+
To trick NumPy into calculating `f(x,y)` on every `x,y` pair, we need to use `np.meshgrid`.
155+
156+
Here we use `np.meshgrid` to create two-dimensional input grids `x_mesh` and `y_mesh`
157+
158+
such that `f(x_mesh, y_mesh)` generates all evaluations on the product grid.
140159

141160

142161
```{code-cell} ipython3
143162
# Large grid
144163
grid = np.linspace(-3, 3, 3_000)
145164
146-
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
165+
x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid
147166
148167
with qe.Timer():
149-
z_max_numpy = np.max(f(x, y))
168+
z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
150169
```
151170

152171
In the vectorized version, all the looping takes place in compiled code.
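To see the difference concretely, here is a toy version with a 3-point grid (the bivariate `f` below is a stand-in consistent with the surrounding text, where the max should be close to one):

```python
import numpy as np

def f(x, y):
    # Stand-in objective (an assumption about the lecture's function)
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

small_grid = np.array([0.0, 1.0, 2.0])

diag = f(small_grid, small_grid)   # shape (3,): only (0,0), (1,1), (2,2)
x_m, y_m = np.meshgrid(small_grid, small_grid)
full = f(x_m, y_m)                 # shape (3, 3): all nine (x, y) pairs
```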
@@ -159,11 +178,30 @@ The output should be close to one:
159178
print(f"NumPy result: {z_max_numpy:.6f}")
160179
```
161180

181+
### Memory Issues
182+
183+
So we have the right solution in reasonable time --- but memory usage is huge.
184+
185+
While the flat arrays are low-memory
186+
187+
```{code-cell} ipython3
188+
grid.nbytes
189+
```
190+
191+
the mesh grids are two-dimensional and hence very memory intensive
192+
193+
```{code-cell} ipython3
194+
x_mesh.nbytes + y_mesh.nbytes
195+
```
196+
197+
Moreover, NumPy's eager execution creates many intermediate arrays of the same size!
198+
199+
This kind of memory usage can be a big problem in actual research calculations.
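To put rough numbers on this (assuming float64 and the 3,000-point grid above):

```python
n = 3_000
bytes_per_float = 8                       # float64

flat_bytes = n * bytes_per_float          # 24 kB for the flat grid
mesh_bytes = 2 * n * n * bytes_per_float  # 144 MB for x_mesh and y_mesh together
```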
162200

163201

164202
### A Comparison with Numba
165203

166-
Now let's see if we can achieve better performance using Numba with a simple loop.
204+
Let's see if we can achieve better performance using Numba with a simple loop.
167205

168206
```{code-cell} ipython3
169207
@numba.jit
@@ -194,15 +232,13 @@ with qe.Timer():
194232
compute_max_numba(grid)
195233
```
196234

197-
Depending on your machine, the Numba version might be either slower or faster than NumPy.
235+
Notice how we are using almost no memory --- we just need the one-dimensional `grid`
198236

199-
In most cases we find that Numba is slightly better.
237+
Moreover, execution speed is good.
200238

201-
On the one hand, NumPy combines efficient arithmetic with some
202-
multithreading, which provides an advantage.
239+
On most machines, the Numba version will be somewhat faster than NumPy.
203240

204-
On the other hand, the Numba routine uses much less memory, since we are only
205-
working with a single one-dimensional grid.
241+
The reason is efficient machine code plus fewer memory reads and writes.
206242

207243

208244
### Parallelized Numba
@@ -301,27 +337,11 @@ The compilation overhead is a one-time cost that pays off when the function is c
301337

302338
### JAX plus vmap
303339

304-
There is one problem with both the NumPy code and the JAX code above:
305-
306-
While the flat arrays are low-memory
307-
308-
```{code-cell} ipython3
309-
grid.nbytes
310-
```
311-
312-
the mesh grids are memory intensive
313-
314-
```{code-cell} ipython3
315-
x_mesh.nbytes + y_mesh.nbytes
316-
```
340+
Because we used `jax.jit` above, we avoided creating many intermediate arrays.
317341

318-
This extra memory usage can be a big problem in actual research calculations.
342+
But we still create the big arrays `x_mesh` and `y_mesh`, as well as the two-dimensional array `f(x_mesh, y_mesh)`.
319343

320-
Fortunately, JAX admits a different approach
321-
using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
322-
323-
The idea of `vmap` is to break vectorization into stages, transforming a
324-
function that operates on single values into one that operates on arrays.
344+
Fortunately, we can avoid this by using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
325345

326346
Here's how we can apply it to our problem.
327347

@@ -330,13 +350,13 @@ Here's how we can apply it to our problem.
330350
@jax.jit
331351
def compute_max_vmap(grid):
332352
# Construct a function that takes the max over all x for given y
333-
f_vec_x_max = lambda y: jnp.max(f(grid, y))
353+
compute_column_max = lambda y: jnp.max(f(grid, y))
334354
# Vectorize the function so we can call on all y simultaneously
335-
f_vec_max = jax.vmap(f_vec_x_max)
336-
# Compute the max across x at every y
337-
maxes = f_vec_max(grid)
338-
# Compute the max of the maxes and return
339-
return jnp.max(maxes)
355+
vectorized_compute_column_max = jax.vmap(compute_column_max)
356+
# Compute the column max at every row
357+
column_maxes = vectorized_compute_column_max(grid)
358+
# Compute the max of the column maxes and return
359+
return jnp.max(column_maxes)
340360
```
341361

342362
Note that we never create
@@ -345,6 +365,8 @@ Note that we never create
345365
* the two-dimensional grid `y_mesh` or
346366
* the two-dimensional array `f(x,y)`
347367

368+
Like Numba, we just use the flat array `grid`.
369+
348370
And because everything is under a single `@jax.jit`, the compiler can fuse
349371
all operations into one optimized kernel.
350372
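As a minimal illustration of what `jax.vmap` does (a toy example, separate from the lecture's function):

```python
import jax
import jax.numpy as jnp

square = lambda x: x ** 2     # operates on a single value
squares = jax.vmap(square)    # now operates on an array of values

squares(jnp.arange(4.0))      # Array([0., 1., 4., 9.], dtype=float32)
```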

@@ -378,18 +400,14 @@ In our view, JAX is the winner for vectorized operations.
378400
It dominates NumPy both in terms of speed (via JIT-compilation and
379401
parallelization) and memory efficiency (via vmap).
380402

381-
Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
382-
383-
While Numba is impressive, the beauty of JAX is that, with fully vectorized
384-
operations, we can run exactly the same code on machines with hardware
385-
accelerators and reap all the benefits without extra effort.
403+
It also dominates Numba when run on the GPU.
386404

387-
Moreover, JAX already knows how to effectively parallelize many common array
388-
operations, which is key to fast execution.
389-
390-
For most cases encountered in economics, econometrics, and finance, it is
405+
```{note}
406+
Numba can support GPU programming through `numba.cuda` but then we need to
407+
parallelize by hand. For most cases encountered in economics, econometrics, and finance, it is
391408
far better to hand over to the JAX compiler for efficient parallelization than to
392-
try to hand code these routines ourselves.
409+
try to hand-code these routines ourselves.
410+
```
393411

394412

395413
## Sequential operations
@@ -554,8 +572,6 @@ The JAX versions, on the other hand, require either `lax.fori_loop` or
554572
While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
555573
remains harder to read than the Numba equivalent.
556574

557-
For this type of sequential operation, Numba is the clear winner in terms of
558-
code clarity and ease of implementation.
559575
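For concreteness, here is a hypothetical sequential recursion written with `lax.fori_loop` and the `.at[t].set` update syntax (a sketch, not the lecture's example):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def cumsum_loop(n):
    # Computes x[t] = x[t-1] + t, i.e. the cumulative sum of 0, 1, ..., n-1
    x = jnp.zeros(n)
    def body(t, x):
        return x.at[t].set(x[t - 1] + t)
    return jax.lax.fori_loop(1, n, body, x)

cumsum_loop(5)  # Array([ 0.,  1.,  3.,  6., 10.], dtype=float32)
```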

560576

561577
## Overall recommendations
@@ -573,25 +589,17 @@ than traditional meshgrid-based vectorization.
573589
In addition, JAX functions are automatically differentiable, as we explore in
574590
{doc}`autodiff`.
575591

576-
For **sequential operations**, Numba has clear advantages.
592+
For **sequential operations**, Numba has nicer syntax.
577593

578594
The code is natural and readable --- just a Python loop with a decorator ---
579595
and performance is excellent.
580596

581597
JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
582598
the syntax is less intuitive.
583599

584-
```{note}
585-
One important advantage of `lax.fori_loop` and `lax.scan` is that they
586-
support automatic differentiation through the loop, which Numba cannot do.
587-
If you need to differentiate through a sequential computation (e.g., computing
588-
sensitivities of a trajectory to model parameters), JAX is the better choice
589-
despite the less natural syntax.
590-
```
600+
On the other hand, the JAX versions support automatic differentiation.
591601

592-
In practice, many problems involve a mix of both patterns.
602+
That might be of interest if, say, we want to compute sensitivities of a
603+
trajectory to model parameters.
593604

594-
A good rule of thumb: default to JAX for new projects, especially when
595-
hardware acceleration or differentiability might be useful, and reach for Numba
596-
when you have a tight sequential loop that needs to be fast and readable.
597605
