Let's switch to NumPy and use a larger grid.

```{code-cell} ipython3
grid = np.linspace(-3, 3, 3_000)  # Large grid
```

As a first pass at vectorization, we might try something like this:

```{code-cell} ipython3
z = np.max(f(grid, grid))  # This is wrong!
```

The problem here is that `f(grid, grid)` doesn't replicate the nested loop:
NumPy evaluates elementwise, pairing the i-th entry of one array with the
i-th entry of the other.

In terms of the figure above, it only computes the values of `f` along the
diagonal.
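
We can confirm this by checking the shape of the output, which is flat rather
than two-dimensional (a quick check, assuming `f` acts elementwise on arrays):

```{code-cell} ipython3
f(grid, grid).shape  # One value per diagonal point, not a 3_000 x 3_000 array
```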

To trick NumPy into calculating `f(x, y)` on every `(x, y)` pair, we need to
use `np.meshgrid`.

Here we use `np.meshgrid` to create two-dimensional input arrays `x_mesh` and
`y_mesh`, such that `f(x_mesh, y_mesh)` generates all evaluations on the
product grid.
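
To see the product structure, here is `np.meshgrid` applied to a small
three-point grid (an illustrative sketch; the names `a` and `b` are ours):

```{code-cell} ipython3
small = np.array([1, 2, 3])
a, b = np.meshgrid(small, small)
# (a[i, j], b[i, j]) ranges over every pair of points in `small`
print(a)
print(b)
```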

```{code-cell} ipython3
x_mesh, y_mesh = np.meshgrid(grid, grid)  # MATLAB-style meshgrid

with qe.Timer():
    z_max_numpy = np.max(f(x_mesh, y_mesh))  # This works
```

In the vectorized version, all the looping takes place in compiled code.

The output should be close to one:

```{code-cell} ipython3
print(f"NumPy result: {z_max_numpy:.6f}")
```

### Memory Issues

So we have the right solution in reasonable time --- but memory usage is huge.

While the flat array `grid` is low-memory

```{code-cell} ipython3
grid.nbytes
```

the mesh grids are two-dimensional and hence very memory-intensive

```{code-cell} ipython3
x_mesh.nbytes + y_mesh.nbytes
```

Moreover, NumPy's eager execution creates many intermediate arrays of the same size!

This kind of memory usage can be a big problem in actual research calculations.
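
For perspective, each 3,000 by 3,000 array of float64 values takes about 72 MB
(8 bytes per element), and that cost is paid again for every intermediate array:

```{code-cell} ipython3
n = grid.size
print(f"Memory per (n, n) float64 array: {n * n * 8 / 1e6:.0f} MB")
```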


### A Comparison with Numba

Let's see if we can achieve better performance using Numba with a simple loop.

```{code-cell} ipython3
@numba.jit
def compute_max_numba(grid):
    # Track the running maximum over all (x, y) pairs on the grid.
    # The loop body assumes f(x, y) = np.cos(x**2 + y**2) / (1 + x**2 + y**2),
    # matching the function used earlier in the lecture.
    m = -np.inf
    for x in grid:
        for y in grid:
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            m = max(m, z)
    return m
```

Let's time the compiled routine:

```{code-cell} ipython3
with qe.Timer():
    compute_max_numba(grid)
```

Notice how little memory we are using --- we just need the one-dimensional
array `grid`.

Moreover, execution speed is good.

On most machines, the Numba version will be somewhat faster than NumPy.

The reason is that Numba generates efficient machine code and performs far
fewer memory reads and writes.


### Parallelized Numba

The compilation overhead is a one-time cost that pays off when the function is
called repeatedly.

### JAX plus vmap

Because we used `jax.jit` above, we avoided creating many intermediate arrays.

But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`.

Fortunately, we can avoid this by using
[jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).

Here's how we can apply it to our problem.

```{code-cell} ipython3
@jax.jit
def compute_max_vmap(grid):
    # Construct a function that takes the max over all x for given y
    compute_column_max = lambda y: jnp.max(f(grid, y))
    # Vectorize the function so we can call it on all y simultaneously
    vectorized_compute_column_max = jax.vmap(compute_column_max)
    # Compute the column max at every y
    column_maxes = vectorized_compute_column_max(grid)
    # Compute the max of the column maxes and return
    return jnp.max(column_maxes)
```

Note that we never create

* the two-dimensional grid `x_mesh`,
* the two-dimensional grid `y_mesh`, or
* the two-dimensional array `f(x, y)`.

Like Numba, we just use the flat array `grid`.

And because everything is under a single `@jax.jit`, the compiler can fuse
all operations into one optimized kernel.

In our view, JAX is the winner for vectorized operations.

It dominates NumPy both in terms of speed (via JIT-compilation and
parallelization) and memory efficiency (via vmap).

It also dominates Numba when run on the GPU.

```{note}
Numba supports GPU programming through `numba.cuda`, but then we need to
parallelize by hand. For most cases encountered in economics, econometrics,
and finance, it is far better to hand over to the JAX compiler for efficient
parallelization than to try to hand-code these routines ourselves.
```
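
To check which hardware JAX will target on your machine, you can list the
available devices:

```{code-cell} ipython3
jax.devices()  # e.g., CPU, GPU or TPU devices visible to JAX
```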


## Sequential operations

The JAX versions, on the other hand, require either `lax.fori_loop` or
`lax.scan`.

While JAX's `at[t].set` syntax does allow element-wise updates, the overall
code remains harder to read than the Numba equivalent.
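
For reference, here is the functional-update pattern in isolation (a minimal
sketch; the array `v` is ours):

```{code-cell} ipython3
v = jnp.zeros(3)
v = v.at[1].set(42.0)  # Returns a new array; the original is not mutated
v
```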

## Overall recommendations

For **vectorized operations**, we recommend JAX, whose `vmap` approach
requires less memory than traditional meshgrid-based vectorization.

In addition, JAX functions are automatically differentiable, as we explore in
{doc}`autodiff`.

For **sequential operations**, Numba has nicer syntax.

The code is natural and readable --- just a Python loop with a decorator ---
and performance is excellent.

JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
the syntax is less intuitive.
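
For instance, a cumulative sum written with `lax.scan` might look like this
(a minimal sketch, not the code from the section above):

```{code-cell} ipython3
from jax import lax

def cumsum_scan(x):
    # carry is the running sum; we also emit it at each step
    def step(carry, xt):
        carry = carry + xt
        return carry, carry
    _, out = lax.scan(step, 0.0, x)
    return out

cumsum_scan(jnp.arange(5.0))
```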

On the other hand, the JAX versions support automatic differentiation.

That might be of interest if, say, we want to compute sensitivities of a
trajectory to model parameters.
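
For example, here is a toy sketch (the model and names are ours, not from the
lecture) that differentiates the sum of a scanned trajectory with respect to
a parameter:

```{code-cell} ipython3
def trajectory_sum(a):
    # x_{t+1} = a * x_t with x_0 = 1.0; return the sum of the trajectory
    def step(x, _):
        x_next = a * x
        return x_next, x_next
    _, path = lax.scan(step, 1.0, None, length=10)
    return jnp.sum(path)

jax.grad(trajectory_sum)(0.9)
```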