@@ -85,6 +85,13 @@ def _plus_bounds_far_away_float(bnds):
8585 return bnds , bnds .isDefined ()
8686
8787
88+ @jax .vmap
89+ @jax .jit
90+ def _plus_bounds_pos_far_away_float (bnds ):
91+ bnds = bnds + jax_galsim .PositionD (x = 100 , y = 110 )
92+ return bnds , bnds .isDefined ()
93+
94+
8895def test_bounds_jax_vmap_plus_float ():
8996 xmin = jnp .array ([9 , 10 , 11 , 12 ])
9097 xmax = jnp .array ([12 , 11 , 10 , 9 ])
@@ -106,3 +113,19 @@ def test_bounds_jax_vmap_plus_float():
106113 np .testing .assert_array_equal (bnds .xmax [1 :], 110 )
107114 np .testing .assert_array_equal (bnds .ymin [1 :], 100 )
108115 np .testing .assert_array_equal (bnds .ymax [1 :], 110 )
116+
117+ bnds , isdef = _make_bounds_float (xmin , ymin , xmax , ymax )
118+ np .testing .assert_array_equal (bnds .isDefined (), isdef , strict = True )
119+ bnds , isdef = _plus_bounds_pos_far_away_float (bnds )
120+ assert bnds .isDefined ().shape == (4 ,)
121+ np .testing .assert_array_equal (bnds .isDefined (), True )
122+ np .testing .assert_array_equal (bnds .isDefined (), isdef , strict = True )
123+ assert bnds .xmin [0 ] == 9
124+ assert bnds .xmax [0 ] == 100
125+ assert bnds .ymin [0 ] == 9
126+ assert bnds .ymax [0 ] == 110
127+
128+ np .testing .assert_array_equal (bnds .xmin [1 :], 100 )
129+ np .testing .assert_array_equal (bnds .xmax [1 :], 100 )
130+ np .testing .assert_array_equal (bnds .ymin [1 :], 110 )
131+ np .testing .assert_array_equal (bnds .ymax [1 :], 110 )
0 commit comments