Skip to content

Commit 1eb0e86

Browse files
Optimize _gridmake2
## Performance Optimization Summary The optimized code achieves an **884% speedup** (from 1.07ms to 109μs) by replacing NumPy's high-level array operations with **Numba JIT-compiled explicit loops**. ### Key Optimizations **1. Numba JIT Compilation (`@njit(cache=True)`)** - Compiles the function to machine code at runtime, eliminating Python interpreter overhead - The `cache=True` flag stores the compiled version, avoiding recompilation costs on subsequent runs - Particularly effective here because the function contains simple arithmetic and array indexing operations that Numba optimizes well **2. Explicit Loop-Based Construction vs. NumPy Broadcasting** - **Original approach**: Used `np.tile()`, `np.repeat()`, and `np.column_stack()` which create multiple intermediate arrays and perform memory allocations - **Optimized approach**: Pre-allocates the output array once with `np.empty()` and fills it directly using nested loops - This eliminates intermediate array creation and reduces memory allocation overhead **3. Why This Works** From the line profiler, the original code spent: - **76.4%** of time in `np.column_stack([np.tile(...)])` - **8.5%** in `np.repeat()` - **9.3%** in `np.tile()` for the 2D case These NumPy operations, while convenient, involve: - Multiple temporary array allocations - Memory copies during stacking operations - Python-level function call overhead Numba's compiled loops avoid all of this by directly computing each output element in place. ### Impact on Workloads Based on `function_references`, `_gridmake2` is called from `gridmake()` which: - Calls it **once for 2 input arrays** - Calls it **iteratively** for 3+ arrays (once initially, then in a loop for remaining arrays) For multi-array scenarios (3+ inputs), the speedup compounds significantly since `_gridmake2` is called multiple times per `gridmake()` invocation. The nearly **9x speedup** per call translates to substantial gains in computational economics applications where Cartesian products are frequently computed for state space expansions. ### Trade-offs - First call incurs JIT compilation overhead (~tens of milliseconds), but `cache=True` mitigates this for subsequent calls - Code is more verbose but dramatically faster for repeated execution patterns - Best suited for scenarios where the function is called multiple times (amortizing compilation cost)
1 parent bab9ae9 commit 1eb0e86

1 file changed

Lines changed: 33 additions & 23 deletions

File tree

code_to_optimize/discrete_riccati.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""
2-
Utility functions used in CompEcon
1+
"""Utility functions used in CompEcon
32
43
Based routines found in the CompEcon toolbox by Miranda and Fackler.
54
@@ -9,14 +8,16 @@
98
and Finance, MIT Press, 2002.
109
1110
"""
11+
1212
from functools import reduce
13+
1314
import numpy as np
1415
import torch
16+
from numba import njit
1517

1618

1719
def ckron(*arrays):
18-
"""
19-
Repeatedly applies the np.kron function to an arbitrary number of
20+
"""Repeatedly applies the np.kron function to an arbitrary number of
2021
input arrays
2122
2223
Parameters
@@ -43,8 +44,7 @@ def ckron(*arrays):
4344

4445

4546
def gridmake(*arrays):
46-
"""
47-
Expands one or more vectors (or matrices) into a matrix where rows span the
47+
"""Expands one or more vectors (or matrices) into a matrix where rows span the
4848
cartesian product of combinations of the input arrays. Each column of the
4949
input arrays will correspond to one column of the output matrix.
5050
@@ -79,13 +79,12 @@ def gridmake(*arrays):
7979
out = _gridmake2(out, arr)
8080

8181
return out
82-
else:
83-
raise NotImplementedError("Come back here")
82+
raise NotImplementedError("Come back here")
8483

8584

85+
@njit(cache=True)
8686
def _gridmake2(x1, x2):
87-
"""
88-
Expands two vectors (or matrices) into a matrix where rows span the
87+
"""Expands two vectors (or matrices) into a matrix where rows span the
8988
cartesian product of combinations of the input arrays. Each column of the
9089
input arrays will correspond to one column of the output matrix.
9190
@@ -114,19 +113,31 @@ def _gridmake2(x1, x2):
114113
115114
"""
116115
if x1.ndim == 1 and x2.ndim == 1:
117-
return np.column_stack([np.tile(x1, x2.shape[0]),
118-
np.repeat(x2, x1.shape[0])])
119-
elif x1.ndim > 1 and x2.ndim == 1:
120-
first = np.tile(x1, (x2.shape[0], 1))
121-
second = np.repeat(x2, x1.shape[0])
122-
return np.column_stack([first, second])
123-
else:
124-
raise NotImplementedError("Come back here")
116+
n1 = x1.shape[0]
117+
n2 = x2.shape[0]
118+
out = np.empty((n1 * n2, 2), dtype=x1.dtype)
119+
for i in range(n2):
120+
for j in range(n1):
121+
out[i * n1 + j, 0] = x1[j]
122+
out[i * n1 + j, 1] = x2[i]
123+
return out
124+
if x1.ndim > 1 and x2.ndim == 1:
125+
n1 = x1.shape[0]
126+
n2 = x2.shape[0]
127+
n_features = x1.shape[1]
128+
out = np.empty((n1 * n2, n_features + 1), dtype=x1.dtype)
129+
for i in range(n2):
130+
for j in range(n1):
131+
idx = i * n1 + j
132+
for k in range(n_features):
133+
out[idx, k] = x1[j, k]
134+
out[idx, n_features] = x2[i]
135+
return out
136+
raise NotImplementedError("Come back here")
125137

126138

127139
def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
128-
"""
129-
PyTorch version of _gridmake2.
140+
"""PyTorch version of _gridmake2.
130141
131142
Expands two tensors into a matrix where rows span the cartesian product
132143
of combinations of the input tensors. Each column of the input tensors
@@ -161,10 +172,9 @@ def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
161172
first = x1.tile(x2.shape[0])
162173
second = x2.repeat_interleave(x1.shape[0])
163174
return torch.column_stack([first, second])
164-
elif x1.dim() > 1 and x2.dim() == 1:
175+
if x1.dim() > 1 and x2.dim() == 1:
165176
# tile x1 along first dimension
166177
first = x1.tile(x2.shape[0], 1)
167178
second = x2.repeat_interleave(x1.shape[0])
168179
return torch.column_stack([first, second])
169-
else:
170-
raise NotImplementedError("Come back here")
180+
raise NotImplementedError("Come back here")

0 commit comments

Comments
 (0)