|
11 | 11 | """ |
12 | 12 | from functools import reduce |
13 | 13 | import numpy as np |
| 14 | +import torch |
14 | 15 |
|
15 | 16 |
|
16 | 17 | def ckron(*arrays): |
@@ -121,3 +122,49 @@ def _gridmake2(x1, x2): |
121 | 122 | return np.column_stack([first, second]) |
122 | 123 | else: |
123 | 124 | raise NotImplementedError("Come back here") |
| 125 | + |
| 126 | + |
| 127 | +def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: |
| 128 | + """ |
| 129 | + PyTorch version of _gridmake2. |
| 130 | +
|
| 131 | + Expands two tensors into a matrix where rows span the cartesian product |
| 132 | + of combinations of the input tensors. Each column of the input tensors |
| 133 | + will correspond to one column of the output matrix. |
| 134 | +
|
| 135 | + Parameters |
| 136 | + ---------- |
| 137 | + x1 : torch.Tensor |
| 138 | + First tensor to be expanded. |
| 139 | +
|
| 140 | + x2 : torch.Tensor |
| 141 | + Second tensor to be expanded. |
| 142 | +
|
| 143 | + Returns |
| 144 | + ------- |
| 145 | + out : torch.Tensor |
| 146 | + The cartesian product of combinations of the input tensors. |
| 147 | +
|
| 148 | + Notes |
| 149 | + ----- |
| 150 | + Based on original function ``gridmake2`` in CompEcon toolbox by |
| 151 | + Miranda and Fackler. |
| 152 | +
|
| 153 | + References |
| 154 | + ---------- |
| 155 | + Miranda, Mario J, and Paul L Fackler. Applied Computational Economics |
| 156 | + and Finance, MIT Press, 2002. |
| 157 | +
|
| 158 | + """ |
| 159 | + if x1.dim() == 1 and x2.dim() == 1: |
| 160 | + # tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0] |
| 161 | + first = x1.tile(x2.shape[0]) |
| 162 | + second = x2.repeat_interleave(x1.shape[0]) |
| 163 | + return torch.column_stack([first, second]) |
| 164 | + elif x1.dim() > 1 and x2.dim() == 1: |
| 165 | + # tile x1 along first dimension |
| 166 | + first = x1.tile(x2.shape[0], 1) |
| 167 | + second = x2.repeat_interleave(x1.shape[0]) |
| 168 | + return torch.column_stack([first, second]) |
| 169 | + else: |
| 170 | + raise NotImplementedError("Come back here") |
0 commit comments