Skip to content

Commit dea6915

Browse files
committed
test: add tests of bounds and vmap
1 parent 0d07661 commit dea6915

1 file changed

Lines changed: 91 additions & 4 deletions

File tree

tests/jax/test_bounds_jax.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,101 @@
88
@jax.vmap
99
@jax.jit
1010
def _make_bounds_float(xmin, ymin, xmax, ymax):
11-
bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
12-
return bds, bds.isDefined()
11+
bnds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
12+
return bnds, bnds.isDefined()
1313

1414

1515
def test_bounds_jax_vmap_isdefined_float():
1616
xmin = jnp.array([9, 10, 11, 12])
1717
xmax = jnp.array([12, 11, 10, 9])
1818
ymin = jnp.array([9, 11, 10, 12])
1919
ymax = jnp.array([10, 10, 10, 10])
20-
bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
21-
np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True)
20+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
21+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
22+
23+
24+
@jax.vmap
25+
@jax.jit
26+
def _and_bounds_empty_float(bnds):
27+
bnds = bnds & jax_galsim.BoundsD()
28+
return bnds, bnds.isDefined()
29+
30+
31+
@jax.vmap
32+
@jax.jit
33+
def _and_bounds_float(bnds):
34+
bnds = bnds & jax_galsim.BoundsD(xmin=10, xmax=11, ymin=10, ymax=11)
35+
return bnds, bnds.isDefined()
36+
37+
38+
@jax.vmap
39+
@jax.jit
40+
def _and_bounds_far_away_float(bnds):
41+
bnds = bnds & jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110)
42+
return bnds, bnds.isDefined()
43+
44+
45+
def test_bounds_jax_vmap_and_isdefined_float():
46+
xmin = jnp.array([9, 10, 11, 12])
47+
xmax = jnp.array([12, 11, 10, 9])
48+
ymin = jnp.array([9, 11, 10, 12])
49+
ymax = jnp.array([10, 10, 10, 10])
50+
51+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
52+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
53+
bnds, isdef = _and_bounds_empty_float(bnds)
54+
assert bnds.isDefined().shape == (4,)
55+
assert not jnp.any(bnds.isDefined())
56+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
57+
np.testing.assert_array_equal(bnds.isDefined(), False)
58+
59+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
60+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
61+
bnds, isdef = _and_bounds_float(bnds)
62+
assert bnds.isDefined().shape == (4,)
63+
np.testing.assert_array_equal(
64+
bnds.isDefined(), jnp.array([True, False, False, False])
65+
)
66+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
67+
assert bnds.xmin[0] == 10
68+
assert bnds.xmax[0] == 11
69+
assert bnds.ymin[0] == 10
70+
assert bnds.ymax[0] == 10
71+
72+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
73+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
74+
bnds, isdef = _and_bounds_far_away_float(bnds)
75+
assert bnds.isDefined().shape == (4,)
76+
assert not jnp.any(bnds.isDefined())
77+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
78+
np.testing.assert_array_equal(bnds.isDefined(), False)
79+
80+
81+
@jax.vmap
82+
@jax.jit
83+
def _plus_bounds_far_away_float(bnds):
84+
bnds = bnds + jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110)
85+
return bnds, bnds.isDefined()
86+
87+
88+
def test_bounds_jax_vmap_plus_float():
89+
xmin = jnp.array([9, 10, 11, 12])
90+
xmax = jnp.array([12, 11, 10, 9])
91+
ymin = jnp.array([9, 11, 10, 12])
92+
ymax = jnp.array([10, 10, 10, 10])
93+
94+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
95+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
96+
bnds, isdef = _plus_bounds_far_away_float(bnds)
97+
assert bnds.isDefined().shape == (4,)
98+
np.testing.assert_array_equal(bnds.isDefined(), True)
99+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
100+
assert bnds.xmin[0] == 9
101+
assert bnds.xmax[0] == 110
102+
assert bnds.ymin[0] == 9
103+
assert bnds.ymax[0] == 110
104+
105+
np.testing.assert_array_equal(bnds.xmin[1:], 100)
106+
np.testing.assert_array_equal(bnds.xmax[1:], 110)
107+
np.testing.assert_array_equal(bnds.ymin[1:], 100)
108+
np.testing.assert_array_equal(bnds.ymax[1:], 110)

0 commit comments

Comments
 (0)