Skip to content

Commit 638adca

Browse files
author
Codeflash Bot
committed
pytorch version
1 parent d6ad12d commit 638adca

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

code_to_optimize/discrete_riccati.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
from functools import reduce
1313
import numpy as np
14+
import torch
1415

1516

1617
def ckron(*arrays):
@@ -121,3 +122,49 @@ def _gridmake2(x1, x2):
121122
return np.column_stack([first, second])
122123
else:
123124
raise NotImplementedError("Come back here")
125+
126+
127+
def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
128+
"""
129+
PyTorch version of _gridmake2.
130+
131+
Expands two tensors into a matrix where rows span the cartesian product
132+
of combinations of the input tensors. Each column of the input tensors
133+
will correspond to one column of the output matrix.
134+
135+
Parameters
136+
----------
137+
x1 : torch.Tensor
138+
First tensor to be expanded.
139+
140+
x2 : torch.Tensor
141+
Second tensor to be expanded.
142+
143+
Returns
144+
-------
145+
out : torch.Tensor
146+
The cartesian product of combinations of the input tensors.
147+
148+
Notes
149+
-----
150+
Based on original function ``gridmake2`` in CompEcon toolbox by
151+
Miranda and Fackler.
152+
153+
References
154+
----------
155+
Miranda, Mario J, and Paul L Fackler. Applied Computational Economics
156+
and Finance, MIT Press, 2002.
157+
158+
"""
159+
if x1.dim() == 1 and x2.dim() == 1:
160+
# tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0]
161+
first = x1.tile(x2.shape[0])
162+
second = x2.repeat_interleave(x1.shape[0])
163+
return torch.column_stack([first, second])
164+
elif x1.dim() > 1 and x2.dim() == 1:
165+
# tile x1 along first dimension
166+
first = x1.tile(x2.shape[0], 1)
167+
second = x2.repeat_interleave(x1.shape[0])
168+
return torch.column_stack([first, second])
169+
else:
170+
raise NotImplementedError("Come back here")

0 commit comments

Comments
 (0)