Skip to content

Commit 1671b37

Browse files
committed
test: add tests of bounds and vmap
1 parent dea6915 commit 1671b37

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

tests/jax/test_bounds_jax.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def _plus_bounds_far_away_float(bnds):
8585
return bnds, bnds.isDefined()
8686

8787

88+
@jax.vmap
89+
@jax.jit
90+
def _plus_bounds_pos_far_away_float(bnds):
91+
bnds = bnds + jax_galsim.PositionD(x=100, y=110)
92+
return bnds, bnds.isDefined()
93+
94+
8895
def test_bounds_jax_vmap_plus_float():
8996
xmin = jnp.array([9, 10, 11, 12])
9097
xmax = jnp.array([12, 11, 10, 9])
@@ -106,3 +113,19 @@ def test_bounds_jax_vmap_plus_float():
106113
np.testing.assert_array_equal(bnds.xmax[1:], 110)
107114
np.testing.assert_array_equal(bnds.ymin[1:], 100)
108115
np.testing.assert_array_equal(bnds.ymax[1:], 110)
116+
117+
bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
118+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
119+
bnds, isdef = _plus_bounds_pos_far_away_float(bnds)
120+
assert bnds.isDefined().shape == (4,)
121+
np.testing.assert_array_equal(bnds.isDefined(), True)
122+
np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True)
123+
assert bnds.xmin[0] == 9
124+
assert bnds.xmax[0] == 100
125+
assert bnds.ymin[0] == 9
126+
assert bnds.ymax[0] == 110
127+
128+
np.testing.assert_array_equal(bnds.xmin[1:], 100)
129+
np.testing.assert_array_equal(bnds.xmax[1:], 100)
130+
np.testing.assert_array_equal(bnds.ymin[1:], 110)
131+
np.testing.assert_array_equal(bnds.ymax[1:], 110)

0 commit comments

Comments
 (0)