Skip to content

Commit bab9ae9

Browse files
author
Codeflash Bot
committed
pytorch version
1 parent 638adca commit bab9ae9

1 file changed

Lines changed: 306 additions & 0 deletions

File tree

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
from code_to_optimize.discrete_riccati import _gridmake2, _gridmake2_torch
6+
7+
8+
class TestGridmake2TorchCPU:
9+
"""Tests for _gridmake2_torch with CPU tensors."""
10+
11+
def test_both_1d_simple(self):
12+
"""Test with two simple 1D tensors."""
13+
x1 = torch.tensor([1, 2, 3])
14+
x2 = torch.tensor([10, 20])
15+
16+
result = _gridmake2_torch(x1, x2)
17+
18+
# Expected: x1 tiled x2.shape[0] times, x2 repeat_interleaved x1.shape[0]
19+
# x1 tiled: [1, 2, 3, 1, 2, 3]
20+
# x2 repeated: [10, 10, 10, 20, 20, 20]
21+
expected = torch.tensor([
22+
[1, 10],
23+
[2, 10],
24+
[3, 10],
25+
[1, 20],
26+
[2, 20],
27+
[3, 20],
28+
])
29+
assert torch.equal(result, expected)
30+
31+
def test_both_1d_matches_numpy(self):
32+
"""Test that torch version matches numpy version for 1D inputs."""
33+
x1_np = np.array([1.0, 2.0, 3.0, 4.0])
34+
x2_np = np.array([10.0, 20.0, 30.0])
35+
36+
x1_torch = torch.tensor(x1_np)
37+
x2_torch = torch.tensor(x2_np)
38+
39+
result_np = _gridmake2(x1_np, x2_np)
40+
result_torch = _gridmake2_torch(x1_torch, x2_torch)
41+
42+
np.testing.assert_array_almost_equal(result_np, result_torch.numpy())
43+
44+
def test_both_1d_single_element(self):
45+
"""Test with single element tensors."""
46+
x1 = torch.tensor([5])
47+
x2 = torch.tensor([10])
48+
49+
result = _gridmake2_torch(x1, x2)
50+
51+
expected = torch.tensor([[5, 10]])
52+
assert torch.equal(result, expected)
53+
54+
def test_both_1d_float_tensors(self):
55+
"""Test with float tensors."""
56+
x1 = torch.tensor([1.5, 2.5])
57+
x2 = torch.tensor([0.1, 0.2, 0.3])
58+
59+
result = _gridmake2_torch(x1, x2)
60+
61+
assert result.shape == (6, 2)
62+
assert result.dtype == torch.float32
63+
64+
def test_2d_and_1d_simple(self):
65+
"""Test with 2D x1 and 1D x2."""
66+
x1 = torch.tensor([[1, 2], [3, 4]])
67+
x2 = torch.tensor([10, 20])
68+
69+
result = _gridmake2_torch(x1, x2)
70+
71+
# x1 tiled along first dim: [[1, 2], [3, 4], [1, 2], [3, 4]]
72+
# x2 repeated: [10, 10, 20, 20]
73+
# column_stack: [[1, 2, 10], [3, 4, 10], [1, 2, 20], [3, 4, 20]]
74+
expected = torch.tensor([
75+
[1, 2, 10],
76+
[3, 4, 10],
77+
[1, 2, 20],
78+
[3, 4, 20],
79+
])
80+
assert torch.equal(result, expected)
81+
82+
def test_2d_and_1d_matches_numpy(self):
83+
"""Test that torch version matches numpy version for 2D, 1D inputs."""
84+
x1_np = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
85+
x2_np = np.array([10.0, 20.0])
86+
87+
x1_torch = torch.tensor(x1_np)
88+
x2_torch = torch.tensor(x2_np)
89+
90+
result_np = _gridmake2(x1_np, x2_np)
91+
result_torch = _gridmake2_torch(x1_torch, x2_torch)
92+
93+
np.testing.assert_array_almost_equal(result_np, result_torch.numpy())
94+
95+
def test_2d_and_1d_single_column(self):
96+
"""Test with 2D x1 having a single column and 1D x2."""
97+
x1 = torch.tensor([[1], [2], [3]])
98+
x2 = torch.tensor([10, 20])
99+
100+
result = _gridmake2_torch(x1, x2)
101+
102+
expected = torch.tensor([
103+
[1, 10],
104+
[2, 10],
105+
[3, 10],
106+
[1, 20],
107+
[2, 20],
108+
[3, 20],
109+
])
110+
assert torch.equal(result, expected)
111+
112+
def test_output_shape_1d_1d(self):
113+
"""Test output shape for two 1D tensors."""
114+
x1 = torch.tensor([1, 2, 3, 4, 5])
115+
x2 = torch.tensor([10, 20, 30])
116+
117+
result = _gridmake2_torch(x1, x2)
118+
119+
# Shape should be (len(x1) * len(x2), 2)
120+
assert result.shape == (15, 2)
121+
122+
def test_output_shape_2d_1d(self):
123+
"""Test output shape for 2D and 1D tensors."""
124+
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Shape (2, 3)
125+
x2 = torch.tensor([10, 20, 30, 40]) # Shape (4,)
126+
127+
result = _gridmake2_torch(x1, x2)
128+
129+
# Shape should be (2 * 4, 3 + 1) = (8, 4)
130+
assert result.shape == (8, 4)
131+
132+
def test_not_implemented_for_2d_2d(self):
133+
"""Test that NotImplementedError is raised for two 2D tensors."""
134+
x1 = torch.tensor([[1, 2], [3, 4]])
135+
x2 = torch.tensor([[10, 20], [30, 40]])
136+
137+
with pytest.raises(NotImplementedError, match="Come back here"):
138+
_gridmake2_torch(x1, x2)
139+
140+
def test_not_implemented_for_1d_2d(self):
141+
"""Test that NotImplementedError is raised for 1D and 2D tensors."""
142+
x1 = torch.tensor([1, 2, 3])
143+
x2 = torch.tensor([[10, 20], [30, 40]])
144+
145+
with pytest.raises(NotImplementedError, match="Come back here"):
146+
_gridmake2_torch(x1, x2)
147+
148+
def test_preserves_dtype_int(self):
149+
"""Test that integer dtype is preserved."""
150+
x1 = torch.tensor([1, 2, 3], dtype=torch.int32)
151+
x2 = torch.tensor([10, 20], dtype=torch.int32)
152+
153+
result = _gridmake2_torch(x1, x2)
154+
155+
assert result.dtype == torch.int32
156+
157+
def test_preserves_dtype_float64(self):
158+
"""Test that float64 dtype is preserved."""
159+
x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
160+
x2 = torch.tensor([10.0, 20.0], dtype=torch.float64)
161+
162+
result = _gridmake2_torch(x1, x2)
163+
164+
assert result.dtype == torch.float64
165+
166+
def test_large_tensors(self):
167+
"""Test with larger tensors."""
168+
x1 = torch.arange(100)
169+
x2 = torch.arange(50)
170+
171+
result = _gridmake2_torch(x1, x2)
172+
173+
assert result.shape == (5000, 2)
174+
# Verify first and last elements
175+
assert result[0, 0] == 0 and result[0, 1] == 0
176+
assert result[-1, 0] == 99 and result[-1, 1] == 49
177+
178+
179+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
180+
class TestGridmake2TorchCUDA:
181+
"""Tests for _gridmake2_torch with CUDA tensors."""
182+
183+
def test_both_1d_simple_cuda(self):
184+
"""Test with two simple 1D CUDA tensors."""
185+
x1 = torch.tensor([1, 2, 3], device="cuda")
186+
x2 = torch.tensor([10, 20], device="cuda")
187+
188+
result = _gridmake2_torch(x1, x2)
189+
190+
expected = torch.tensor([
191+
[1, 10],
192+
[2, 10],
193+
[3, 10],
194+
[1, 20],
195+
[2, 20],
196+
[3, 20],
197+
], device="cuda")
198+
assert result.device.type == "cuda"
199+
assert torch.equal(result, expected)
200+
201+
def test_both_1d_matches_cpu(self):
202+
"""Test that CUDA version matches CPU version."""
203+
x1_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0])
204+
x2_cpu = torch.tensor([10.0, 20.0, 30.0])
205+
206+
x1_cuda = x1_cpu.cuda()
207+
x2_cuda = x2_cpu.cuda()
208+
209+
result_cpu = _gridmake2_torch(x1_cpu, x2_cpu)
210+
result_cuda = _gridmake2_torch(x1_cuda, x2_cuda)
211+
212+
assert result_cuda.device.type == "cuda"
213+
torch.testing.assert_close(result_cpu, result_cuda.cpu())
214+
215+
def test_2d_and_1d_cuda(self):
216+
"""Test with 2D x1 and 1D x2 on CUDA."""
217+
x1 = torch.tensor([[1, 2], [3, 4]], device="cuda")
218+
x2 = torch.tensor([10, 20], device="cuda")
219+
220+
result = _gridmake2_torch(x1, x2)
221+
222+
expected = torch.tensor([
223+
[1, 2, 10],
224+
[3, 4, 10],
225+
[1, 2, 20],
226+
[3, 4, 20],
227+
], device="cuda")
228+
assert result.device.type == "cuda"
229+
assert torch.equal(result, expected)
230+
231+
def test_2d_and_1d_matches_cpu(self):
232+
"""Test that CUDA version matches CPU version for 2D, 1D inputs."""
233+
x1_cpu = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
234+
x2_cpu = torch.tensor([10.0, 20.0])
235+
236+
x1_cuda = x1_cpu.cuda()
237+
x2_cuda = x2_cpu.cuda()
238+
239+
result_cpu = _gridmake2_torch(x1_cpu, x2_cpu)
240+
result_cuda = _gridmake2_torch(x1_cuda, x2_cuda)
241+
242+
assert result_cuda.device.type == "cuda"
243+
torch.testing.assert_close(result_cpu, result_cuda.cpu())
244+
245+
def test_output_stays_on_cuda(self):
246+
"""Test that output tensor stays on CUDA device."""
247+
x1 = torch.tensor([1, 2, 3], device="cuda")
248+
x2 = torch.tensor([10, 20], device="cuda")
249+
250+
result = _gridmake2_torch(x1, x2)
251+
252+
assert result.is_cuda
253+
254+
def test_preserves_dtype_float32_cuda(self):
255+
"""Test that float32 dtype is preserved on CUDA."""
256+
x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
257+
x2 = torch.tensor([10.0, 20.0], dtype=torch.float32, device="cuda")
258+
259+
result = _gridmake2_torch(x1, x2)
260+
261+
assert result.dtype == torch.float32
262+
assert result.device.type == "cuda"
263+
264+
def test_preserves_dtype_float64_cuda(self):
265+
"""Test that float64 dtype is preserved on CUDA."""
266+
x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, device="cuda")
267+
x2 = torch.tensor([10.0, 20.0], dtype=torch.float64, device="cuda")
268+
269+
result = _gridmake2_torch(x1, x2)
270+
271+
assert result.dtype == torch.float64
272+
assert result.device.type == "cuda"
273+
274+
def test_large_tensors_cuda(self):
275+
"""Test with larger tensors on CUDA."""
276+
x1 = torch.arange(100, device="cuda")
277+
x2 = torch.arange(50, device="cuda")
278+
279+
result = _gridmake2_torch(x1, x2)
280+
281+
assert result.shape == (5000, 2)
282+
assert result.device.type == "cuda"
283+
# Verify first and last elements
284+
assert result[0, 0].item() == 0 and result[0, 1].item() == 0
285+
assert result[-1, 0].item() == 99 and result[-1, 1].item() == 49
286+
287+
def test_not_implemented_for_2d_2d_cuda(self):
288+
"""Test that NotImplementedError is raised for two 2D CUDA tensors."""
289+
x1 = torch.tensor([[1, 2], [3, 4]], device="cuda")
290+
x2 = torch.tensor([[10, 20], [30, 40]], device="cuda")
291+
292+
with pytest.raises(NotImplementedError, match="Come back here"):
293+
_gridmake2_torch(x1, x2)
294+
295+
def test_matches_numpy_via_cpu_conversion(self):
296+
"""Test CUDA result matches numpy version via CPU conversion."""
297+
x1_np = np.array([1.0, 2.0, 3.0, 4.0])
298+
x2_np = np.array([10.0, 20.0, 30.0])
299+
300+
x1_cuda = torch.tensor(x1_np, device="cuda")
301+
x2_cuda = torch.tensor(x2_np, device="cuda")
302+
303+
result_np = _gridmake2(x1_np, x2_np)
304+
result_cuda = _gridmake2_torch(x1_cuda, x2_cuda)
305+
306+
np.testing.assert_array_almost_equal(result_np, result_cuda.cpu().numpy())

0 commit comments

Comments
 (0)