(numpy_numba_jax)=
In the preceding lectures, we've discussed three core libraries for scientific and numerical computing:
- NumPy
- Numba
- JAX
Which one should we use in any given situation?
This lecture addresses that question, at least partially, by discussing some use cases.
Before getting started, we note that the first two are a natural pair: NumPy and Numba play well together.
JAX, on the other hand, stands alone.
When comparing these approaches, we will look not just at efficiency and memory footprint but also at clarity and ease of use.
In addition to what's in Anaconda, this lecture will need the following libraries:
!pip install quantecon jax
We will use the following imports.
import random
from functools import partial
import numpy as np
import numba
import quantecon as qe
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib import cm
import jax
import jax.numpy as jnp
from jax import lax
Some operations can be perfectly vectorized --- all loops are easily eliminated and numerical operations are reduced to calculations on arrays.
In this case, which approach is best?
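Consider the problem of maximizing a function $f$ of two variables $(x,y)$ over the square $[-3, 3] \times [-3, 3]$.
For $f$ we choose
$$
f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
$$
Here's a plot of $f$.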
def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)
xgrid = np.linspace(-3, 3, 50)
ygrid = xgrid
x, y = np.meshgrid(xgrid, ygrid)
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x,
                y,
                f(x, y),
                rstride=2, cstride=2,
                cmap=cm.viridis,
                alpha=0.7,
                linewidth=0.25)
ax.set_zlim(-0.5, 1.0)
ax.set_xlabel('$x$', fontsize=14)
ax.set_ylabel('$y$', fontsize=14)
plt.show()
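Before using brute force, it helps to know the answer in advance. Since $\cos(\cdot) \le 1$ and the denominator is positive, $f$ is bounded above by 1. Equality holds at the origin:

$$
f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
\le \frac{1}{1 + x^2 + y^2}
\le 1,
\qquad
f(0,0) = \frac{\cos 0}{1} = 1.
$$

The true maximum on the square is therefore 1. This gives a benchmark for checking each method below. The grids used below have an even number of points and so do not contain the origin exactly. The computed maxima should therefore be close to, but slightly below, 1.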
For the sake of this exercise, we're going to use brute force for the maximization.
- Evaluate $f$ for all $(x,y)$ in a grid on the square.
- Return the maximum of observed values.
Just to illustrate the idea, here's a non-vectorized version that uses Python loops.
grid = np.linspace(-3, 3, 50)
m = -np.inf
for x in grid:
    for y in grid:
        z = f(x, y)
        if z > m:
            m = z
If we switch to NumPy-style vectorization we can use a much larger grid and the code executes relatively quickly.
Here we use np.meshgrid to create two-dimensional input grids x and y such
that f(x, y) generates all evaluations on the product grid.
(This strategy dates back to MATLAB.)
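To see what `np.meshgrid` produces, here is a small sketch on a three-point grid. The grid values are arbitrary and chosen only for illustration. With the default `indexing='xy'`, the first output varies along columns and the second along rows. Entry `[i, j]` of `f(x, y)` is therefore $f$ evaluated at `(grid[j], grid[i])`.

```python
import numpy as np

def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

# A tiny, arbitrary grid so the arrays are easy to inspect
small_grid = np.array([-1.0, 0.0, 2.0])
xs, ys = np.meshgrid(small_grid, small_grid)

# Default indexing='xy': xs varies along columns, ys varies along rows
print(xs)
print(ys)

z = f(xs, ys)
# Entry [i, j] of z is f evaluated at (small_grid[j], small_grid[i])
for i in range(len(small_grid)):
    for j in range(len(small_grid)):
        assert np.isclose(z[i, j], f(small_grid[j], small_grid[i]))
```

The price of this convenience is that `xs` and `ys` each hold every point of the product grid.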
grid = np.linspace(-3, 3, 3_000)
x, y = np.meshgrid(grid, grid)
with qe.Timer(precision=8):
    z_max_numpy = np.max(f(x, y))
print(f"NumPy result: {z_max_numpy:.6f}")
In the vectorized version, all the looping takes place in compiled code.
Moreover, NumPy can use implicit multithreading for some operations (notably linear algebra routines backed by a multithreaded BLAS library), so some parallelization may occur.
(The parallelization cannot be highly efficient because the binary is compiled
before it sees the size of the arrays x and y.)
Now let's see if we can achieve better performance using Numba with a simple loop.
@numba.jit
def compute_max_numba(grid):
    m = -np.inf
    for x in grid:
        for y in grid:
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m
grid = np.linspace(-3, 3, 3_000)
with qe.Timer(precision=8):
    z_max_numba = compute_max_numba(grid)
print(f"Numba result: {z_max_numba:.6f}")
Let's run again to eliminate compile time.
with qe.Timer(precision=8):
    compute_max_numba(grid)
Depending on your machine, the Numba version can be a bit slower or a bit faster than NumPy.
On one hand, NumPy combines efficient arithmetic (like Numba) with some multithreading (unlike this Numba code), which provides an advantage.
On the other hand, the Numba routine uses much less memory, since we are only working with a single one-dimensional grid.
Now let's try parallelization with Numba using `prange`.
Here's a naive and incorrect attempt.
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m
This returns -inf --- the initial value of m, as if it were never updated:
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
print(f"Numba result: {z_max_parallel_incorrect} 😱")
To understand why, recall that prange splits the outer loop across threads.
Each thread gets its own private copy of m, initialized to -np.inf, and
correctly updates it within its chunk of iterations.
But at the end of the loop, Numba needs to combine the per-thread copies of m
back into a single value --- a reduction.
For patterns it recognizes, such as m += z (sum) or m = max(m, z) (max),
Numba knows the combining operator.
But it does not recognize the if z > m: m = z pattern as a max reduction, so
the per-thread results are never combined and m retains its initial value.
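The sketch below uses plain Python and NumPy to spell out the two stages of a reduction described above. It does not use Numba and is not how Numba implements `prange` internally. Each chunk of the outer loop keeps its own private running maximum. A final step then combines those partial maxima with `max`. Skipping that combining step is, in effect, what goes wrong in the naive version. The function name `chunked_max` and its `n_chunks` parameter are illustrative assumptions.

```python
import numpy as np

def chunked_max(grid, n_chunks=4):
    """Mimic a parallel max reduction: split the outer loop into chunks,
    give each chunk a private running max, then combine the partial results.
    Chunks run one after another here; this only illustrates the logic."""
    n = len(grid)
    # Each "thread" gets a contiguous block of outer-loop indices
    chunks = np.array_split(np.arange(n), n_chunks)
    partial_maxes = []
    for chunk in chunks:
        m = -np.inf  # private copy, initialized to the identity element for max
        for i in chunk:
            for j in range(n):
                x, y = grid[i], grid[j]
                z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
                if z > m:
                    m = z
        partial_maxes.append(m)
    # The reduction step: combine per-chunk results with the max operator
    return max(partial_maxes)

small_grid = np.linspace(-3, 3, 60)
print(chunked_max(small_grid))
```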
The simplest fix is to replace the conditional with max, which Numba
recognizes:
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            m = max(m, z)
    return m
An alternative is to make the loop body fully independent across i and
handle the reduction ourselves:
@numba.jit(parallel=True)
def compute_max_numba_parallel_v2(grid):
    n = len(grid)
    row_maxes = np.empty(n)
    for i in numba.prange(n):
        row_max = -np.inf
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max
    return np.max(row_maxes)
Here each thread writes to a separate element of row_maxes, so we handle the
reduction ourselves via np.max.
z_max_parallel = compute_max_numba_parallel(grid)
print(f"Numba result: {z_max_parallel:.6f}")
Here's the timing.
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
If you have multiple cores, you should see at least some benefits from parallelization here.
For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
On the surface, vectorized code in JAX is similar to NumPy code.
But there are also some differences, which we highlight here.
Let's start with the function.
@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)
As with NumPy, to get the right shape and the correct nested for loop
calculation, we can use a meshgrid operation designed for this purpose:
grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = jnp.meshgrid(grid, grid)
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh))
    z_max.block_until_ready()
print(f"Plain vanilla JAX result: {z_max:.6f}")
Let's run again to eliminate compile time.
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh))
    z_max.block_until_ready()
Once compiled, JAX is typically significantly faster than NumPy here, especially on a GPU.
The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
There is one problem with both the NumPy code and the JAX code above:
While the flat arrays are low-memory
grid.nbytes
the mesh grids are memory intensive
x_mesh.nbytes + y_mesh.nbytes
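The gap follows directly from the array shapes. The two meshes hold $2n^2$ elements versus $n$ for the flat grid, a ratio of $2n = 6{,}000$ when $n = 3{,}000$. The exact byte counts depend on the float type. JAX defaults to 32-bit floats unless 64-bit mode is enabled, while NumPy defaults to 64-bit floats. The sketch below does this arithmetic from shapes alone, assuming those two dtypes.

```python
import numpy as np

n_grid = 3_000  # grid size used in the timings above

sizes = {}
for dtype in (np.float32, np.float64):
    itemsize = np.dtype(dtype).itemsize     # bytes per element
    flat_bytes = n_grid * itemsize          # one flat 1D grid
    mesh_bytes = 2 * n_grid**2 * itemsize   # two n x n mesh arrays
    sizes[np.dtype(dtype).name] = (flat_bytes, mesh_bytes)
    print(f"{np.dtype(dtype).name}: flat grid {flat_bytes:,} bytes, "
          f"meshes {mesh_bytes:,} bytes")
```

With 32-bit floats, the flat grid needs 12,000 bytes and the two meshes need 72,000,000 bytes. Both figures double with 64-bit floats.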
This extra memory usage can be a big problem in actual research calculations.
Fortunately, JAX admits a different approach using jax.vmap.
The idea of vmap is to break vectorization into stages, transforming a
function that operates on single values into one that operates on arrays.
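Here is a minimal toy illustration of that idea, unrelated to our maximization problem. The function `g` is hypothetical. It is written for scalar inputs. One `jax.vmap` turns it into a function that maps over an array of `x` values. A second `vmap` extends that over `y` values, covering the whole product grid without building mesh arrays.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Written for scalar inputs x and y
    return x * y + 1.0

# Map over x (axis 0 of the first argument), holding y fixed
g_over_x = jax.vmap(g, in_axes=(0, None))
# Then map that function over y, holding the whole x array fixed
g_over_xy = jax.vmap(g_over_x, in_axes=(None, 0))

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([10.0, 20.0])
table = g_over_xy(xs, ys)   # table[i, j] = g(xs[j], ys[i])
print(table.shape)          # (2, 3)
```

The `in_axes` argument tells `vmap` which argument to map over (`0`) and which to broadcast (`None`).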
Here's how we can apply it to our problem.
# Set up f to compute f(x, y) at every x for any given y
f_vec_x = lambda y: f(grid, y)
# Create a second function that vectorizes this operation over all y
f_vec = jax.vmap(f_vec_x)
Now f_vec will compute f(x,y) at every x,y when called with the flat array grid.
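As a quick sanity check on a small grid (the size 7 is arbitrary), the sketch below compares the `vmap` version with the meshgrid version. Row `i`, column `j` of each result is $f$ at `(grid[j], grid[i])`, so the two should agree. The block redefines `f` so it can run on its own.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

small_grid = jnp.linspace(-3, 3, 7)

# Meshgrid approach: materializes two 7 x 7 input arrays
x_m, y_m = jnp.meshgrid(small_grid, small_grid)
z_mesh = f(x_m, y_m)

# vmap approach: for each y, evaluate f at every x in the flat grid
z_vmap = jax.vmap(lambda y: f(small_grid, y))(small_grid)

print(jnp.allclose(z_mesh, z_vmap))  # True
```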
Let's see the timing:
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
print(f"JAX vmap v1 result: {z_max:.6f}")
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
By avoiding the large input arrays x_mesh and y_mesh, this vmap version
uses far less memory, with similar runtime.
But we are still leaving speed gains on the table.
The code above computes the full two-dimensional array f(x,y) and then takes
the max.
Moreover, the jnp.max call sits outside the JIT-compiled function f, so the
compiler cannot fuse these operations into a single kernel.
We can fix both problems by pushing the max inside and wrapping everything in
a single @jax.jit:
@jax.jit
def compute_max_vmap(grid):
    # Construct a function that takes the max along each row
    f_vec_x_max = lambda y: jnp.max(f(grid, y))
    # Vectorize the function so we can call on all rows simultaneously
    f_vec_max = jax.vmap(f_vec_x_max)
    # Call the vectorized function and take the max
    return jnp.max(f_vec_max(grid))
Here,
- `f_vec_x_max` computes the max along any given row
- `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
We apply this function to all rows and then take the max of the row maxes.
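This two-stage computation gives the same answer because a maximum over a two-dimensional array can be taken first within each row and then across rows. Write $x_j$ and $y_i$ for the grid points and $n$ for the grid size. Then:

$$
\max_{1 \le i, j \le n} f(x_j, y_i)
= \max_{1 \le i \le n} \left( \max_{1 \le j \le n} f(x_j, y_i) \right)
$$

The inner maximum is the row max that `f_vec_x_max` computes for the row indexed by $y_i$.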
Because we push the max inside, we never construct the full two-dimensional
array f(x,y), saving even more memory.
And because everything is under a single @jax.jit, the compiler can fuse
all operations into one optimized kernel.
Let's try it.
with qe.Timer(precision=8):
    z_max = compute_max_vmap(grid).block_until_ready()
print(f"JAX vmap result: {z_max:.6f}")
Let's run it again to eliminate compilation time:
with qe.Timer(precision=8):
    z_max = compute_max_vmap(grid).block_until_ready()
In our view, JAX is the winner for vectorized operations.
It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).
Moreover, the vmap approach can sometimes lead to significantly clearer code.
While Numba is impressive, the beauty of JAX is that, with fully vectorized operations, we can run exactly the same code on machines with hardware accelerators and reap all the benefits without extra effort.
Moreover, JAX already knows how to effectively parallelize many common array operations, which is key to fast execution.
For most cases encountered in economics, econometrics, and finance, it is far better to hand over to the JAX compiler for efficient parallelization than to try to hand code these routines ourselves.
Some operations are inherently sequential, and hence difficult or impossible to vectorize.
In this case NumPy is a poor option and we are left with the choice of Numba or JAX.
To compare these choices, we will revisit the problem of iterating on the
quadratic map that we saw in our {doc}`Numba lecture <numba>`.
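For reference, the quadratic map is $x_{t+1} = \alpha x_t (1 - x_t)$. The plain-Python sketch below has no compilation and is only suitable for short series. It makes the sequential dependence explicit: each step needs the value from the previous one. Starting from $x_0 = 0.1$ with $\alpha = 4$, the first iterates are $0.36$, $0.9216$ and $0.28901376$. The name `qm_python` is our own.

```python
def qm_python(x0, n, α=4.0):
    # Plain Python list; each new entry depends on the one before it
    x = [x0]
    for t in range(n):
        x.append(α * x[t] * (1 - x[t]))
    return x

print(qm_python(0.1, 3))
```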
Here's the Numba version.
@numba.jit
def qm(x0, n, α=4.0):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = α * x[t] * (1 - x[t])
    return x
Let's generate a time series of length 10,000,000 and time the execution:
n = 10_000_000
with qe.Timer(precision=8):
    x = qm(0.1, n)
Let's run it again to eliminate compilation time:
with qe.Timer(precision=8):
    x = qm(0.1, n)
Numba handles this sequential operation very efficiently.
Notice that the second run is significantly faster after JIT compilation completes.
Numba's compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one.
Now let's create a JAX version using lax.scan:
(We'll hold n static because it affects array size and hence JAX wants to specialize on its value in the compiled code.)
cpu = jax.devices("cpu")[0]
@partial(jax.jit, static_argnums=(1,), device=cpu)
def qm_jax(x0, n, α=4.0):
    def update(x, t):
        x_new = α * x * (1 - x)
        return x_new, x_new
    _, x = lax.scan(update, x0, jnp.arange(n))
    return jnp.concatenate([jnp.array([x0]), x])
This code is not easy to read but, in essence, lax.scan repeatedly calls update and accumulates the returns x_new into an array.
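To make the `lax.scan` pattern concrete, here is a toy sketch that computes a running sum and is unrelated to the quadratic map. The body function receives the current carry and one element of the input sequence. It returns a pair: the new carry, and the value to record for that step. `lax.scan` returns the final carry together with the stacked per-step outputs.

```python
import jax.numpy as jnp
from jax import lax

def step(total, x):
    new_total = total + x
    # (carry passed to the next step, output recorded for this step)
    return new_total, new_total

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final_total, running = lax.scan(step, jnp.float32(0.0), xs)
print(final_total)  # final carry: 10
print(running)      # per-step outputs: 1, 3, 6, 10
```

In `qm_jax`, the carry is the current state `x` and the recorded output is `x_new`, so the stacked outputs form the time series.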
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
The computation consists of many small sequential operations, leaving little
opportunity for the GPU to exploit parallelism.
As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU
a better fit for this workload.
Curious readers can try removing this option to see how performance changes.
Let's time it with the same parameters:
with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
Let's run it again to eliminate compilation overhead:
with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
JAX is also quite efficient for this sequential operation, so both libraries deliver strong performance after compilation.
There are, however, significant differences in code readability and ease of use.
The Numba version is straightforward and natural to read: we simply allocate an array and fill it element by element using a standard Python loop.
This is exactly how most programmers think about the algorithm.
The JAX version, on the other hand, requires using lax.scan, which is significantly less intuitive.
Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
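For comparison, the sketch below mirrors the Numba loop more directly. It uses `lax.fori_loop` and JAX's functional update syntax `x.at[i].set(v)`, which returns a new array rather than modifying `x` in place. The name `qm_fori` is our own. This is one possible approach, not necessarily the fastest, and it still needs more ceremony than the Numba version.

```python
import jax.numpy as jnp
from jax import lax

def qm_fori(x0, n, α=4.0):
    # Preallocate; .at[0].set(x0) returns a new array with x0 in slot 0
    x = jnp.zeros(n + 1).at[0].set(x0)

    def body(t, x):
        # Build an updated copy rather than assigning x[t+1] in place
        return x.at[t + 1].set(α * x[t] * (1 - x[t]))

    return lax.fori_loop(0, n, body, x)

x_fori = qm_fori(0.1, 3)

# .at[...].set(...) leaves the original array unchanged
a = jnp.zeros(3)
b = a.at[0].set(1.0)
print(a, b)
```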
For this type of sequential operation, Numba is the clear winner in terms of code clarity and ease of implementation.
Let's now step back and summarize the trade-offs.
For vectorized operations, JAX is the strongest choice.
It matches or exceeds NumPy in speed, thanks to JIT compilation and efficient parallelization across CPUs and GPUs.
The vmap transformation reduces memory usage and often leads to clearer code
than traditional meshgrid-based vectorization.
In addition, JAX functions are automatically differentiable, as we explore in
{doc}`autodiff`.
For sequential operations, Numba has clear advantages.
The code is natural and readable --- just a Python loop with a decorator --- and performance is excellent.
JAX can handle sequential problems via lax.scan, but the syntax is less
intuitive.
One important advantage of `lax.scan` is that it supports automatic
differentiation through the loop, which Numba cannot do.
If you need to differentiate through a sequential computation (e.g., computing
sensitivities of a trajectory to model parameters), JAX is the better choice
despite the less natural syntax.
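As a small illustration, suppose we care about how the state after `n` steps responds to $\alpha$. The sketch below differentiates through `lax.scan` with `jax.grad`. The function `final_state` is hypothetical, and the values $x_0 = 0.1$, $\alpha = 4$ and $n = 2$ are chosen only so the answer can be checked by hand. Let $s_t$ denote $\partial x_t / \partial \alpha$. Differentiating $x_{t+1} = \alpha x_t (1 - x_t)$ gives $s_{t+1} = x_t (1 - x_t) + \alpha (1 - 2 x_t) s_t$ with $s_0 = 0$. Hence $s_1 = 0.09$ and $s_2 = 0.2304 + 4 \times 0.28 \times 0.09 = 0.3312$. The automatic derivative should match this up to floating-point error.

```python
import jax
import jax.numpy as jnp
from jax import lax

def final_state(α, x0=0.1, n=2):
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def update(x, _):
        # Same update rule as qm_jax; no per-step output is recorded
        return α * x * (1 - x), None

    x_final, _ = lax.scan(update, x0, None, length=n)
    return x_final

# Derivative of x_n with respect to α, via reverse-mode autodiff
dx_dα = jax.grad(final_state)(jnp.float32(4.0))
print(dx_dα)
```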
In practice, many problems involve a mix of both patterns.
A good rule of thumb: default to JAX for new projects, especially when hardware acceleration or differentiability might be useful, and reach for Numba when you have a tight sequential loop that needs to be fast and readable.