Skip to content

Commit 5b9058c

Browse files
Optimize _gridmake2
The optimized code achieves a **10x speedup** (1038%) by replacing NumPy's high-level array operations with JIT-compiled explicit loops via Numba's `@njit` decorator. ## Key Optimizations **1. Numba JIT Compilation with `@njit(cache=True)`** - Eliminates Python interpreter overhead by compiling to machine code - The `cache=True` flag stores compiled code between runs, avoiding recompilation cost - Particularly effective for loops, which NumPy operations like `tile`, `repeat`, and `column_stack` use internally but with Python overhead **2. Preallocated Output Arrays with Explicit Loops** - **Original approach**: `np.column_stack([np.tile(x1, x2.shape[0]), np.repeat(x2, x1.shape[0])])` creates three temporary arrays (tile result, repeat result, then column_stack result) - **Optimized approach**: Pre-allocates a single output array with exact size `(x1.shape[0] * x2.shape[0], 2)` and fills it directly via nested loops - Eliminates intermediate array allocations and memory copies **3. Direct Memory Access** - Line profiler shows the original code spends 77.9% of time in `np.column_stack` and related operations - The optimized version replaces these with direct index assignments (`out[idx, 0] = x1[i]`), which Numba compiles to efficient memory writes ## Performance Context From `function_references`, `_gridmake2` is called recursively within `gridmake()` when building cartesian products of multiple arrays. For `d > 2` dimensions, the function is called `d-1` times in a loop. This means: - **Hot path impact**: The 10x speedup compounds across multiple calls when expanding 3+ dimensional grids - **Memory efficiency**: For large input arrays, avoiding temporary allocations becomes increasingly important ## Test Case Suitability The optimization excels when: - Building cartesian products of moderately-sized vectors (e.g., 100-1000 elements each) - Called repeatedly in loops (as in the recursive `gridmake` case) - Input arrays have consistent dtypes (Numba's type specialization works best here) The line profiler confirms the bottleneck was NumPy's high-level operations, which this optimization directly addresses through low-level compiled code.
1 parent bab9ae9 commit 5b9058c

1 file changed

Lines changed: 34 additions & 23 deletions

File tree

code_to_optimize/discrete_riccati.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""
2-
Utility functions used in CompEcon
1+
"""Utility functions used in CompEcon
32
43
Based routines found in the CompEcon toolbox by Miranda and Fackler.
54
@@ -9,14 +8,16 @@
98
and Finance, MIT Press, 2002.
109
1110
"""
11+
1212
from functools import reduce
13+
1314
import numpy as np
1415
import torch
16+
from numba import njit
1517

1618

1719
def ckron(*arrays):
18-
"""
19-
Repeatedly applies the np.kron function to an arbitrary number of
20+
"""Repeatedly applies the np.kron function to an arbitrary number of
2021
input arrays
2122
2223
Parameters
@@ -43,8 +44,7 @@ def ckron(*arrays):
4344

4445

4546
def gridmake(*arrays):
46-
"""
47-
Expands one or more vectors (or matrices) into a matrix where rows span the
47+
"""Expands one or more vectors (or matrices) into a matrix where rows span the
4848
cartesian product of combinations of the input arrays. Each column of the
4949
input arrays will correspond to one column of the output matrix.
5050
@@ -79,13 +79,12 @@ def gridmake(*arrays):
7979
out = _gridmake2(out, arr)
8080

8181
return out
82-
else:
83-
raise NotImplementedError("Come back here")
82+
raise NotImplementedError("Come back here")
8483

8584

85+
@njit(cache=True)
8686
def _gridmake2(x1, x2):
87-
"""
88-
Expands two vectors (or matrices) into a matrix where rows span the
87+
"""Expands two vectors (or matrices) into a matrix where rows span the
8988
cartesian product of combinations of the input arrays. Each column of the
9089
input arrays will correspond to one column of the output matrix.
9190
@@ -114,19 +113,32 @@ def _gridmake2(x1, x2):
114113
115114
"""
116115
if x1.ndim == 1 and x2.ndim == 1:
117-
return np.column_stack([np.tile(x1, x2.shape[0]),
118-
np.repeat(x2, x1.shape[0])])
119-
elif x1.ndim > 1 and x2.ndim == 1:
120-
first = np.tile(x1, (x2.shape[0], 1))
121-
second = np.repeat(x2, x1.shape[0])
122-
return np.column_stack([first, second])
123-
else:
124-
raise NotImplementedError("Come back here")
116+
out = np.empty((x1.shape[0] * x2.shape[0], 2), dtype=x1.dtype)
117+
idx = 0
118+
for j in range(x2.shape[0]):
119+
for i in range(x1.shape[0]):
120+
out[idx, 0] = x1[i]
121+
out[idx, 1] = x2[j]
122+
idx += 1
123+
return out
124+
if x1.ndim > 1 and x2.ndim == 1:
125+
n1 = x1.shape[0]
126+
n2 = x2.shape[0]
127+
ncols = x1.shape[1] + 1
128+
out = np.empty((n1 * n2, ncols), dtype=x1.dtype)
129+
idx = 0
130+
for j in range(n2):
131+
for i in range(n1):
132+
for k in range(x1.shape[1]):
133+
out[idx, k] = x1[i, k]
134+
out[idx, -1] = x2[j]
135+
idx += 1
136+
return out
137+
raise NotImplementedError("Come back here")
125138

126139

127140
def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
128-
"""
129-
PyTorch version of _gridmake2.
141+
"""PyTorch version of _gridmake2.
130142
131143
Expands two tensors into a matrix where rows span the cartesian product
132144
of combinations of the input tensors. Each column of the input tensors
@@ -161,10 +173,9 @@ def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
161173
first = x1.tile(x2.shape[0])
162174
second = x2.repeat_interleave(x1.shape[0])
163175
return torch.column_stack([first, second])
164-
elif x1.dim() > 1 and x2.dim() == 1:
176+
if x1.dim() > 1 and x2.dim() == 1:
165177
# tile x1 along first dimension
166178
first = x1.tile(x2.shape[0], 1)
167179
second = x2.repeat_interleave(x1.shape[0])
168180
return torch.column_stack([first, second])
169-
else:
170-
raise NotImplementedError("Come back here")
181+
raise NotImplementedError("Come back here")

0 commit comments

Comments
 (0)