@@ -29,10 +29,12 @@ def _Knu(nu, x):
2929@implements (
3030 _galsim .Moffat ,
3131 lax_description = """\
32- The LAX version of the Moffat profile
32+ The JAX-GalSim version of the Moffat profile
3333
3434- does not support truncation or beta < 1.1
3535- does not support gsparams.maxk_thresholds > 0.1
36+ - does not support autodiff with respect to the `beta` parameter
37+ for Fourier-space evaluations
3638""" ,
3739)
3840@register_pytree_node_class
@@ -333,8 +335,10 @@ def _kValue_untrunc_interp_coeffs(self):
333335 k_min = 0
334336 k_max = self ._maxk
335337 k = jnp .linspace (k_min , k_max , n_pts )
338+
339+ beta = jax .lax .stop_gradient (self .beta )
336340 vals = self ._kValue_untrunc_func (
337- self . beta ,
341+ beta ,
338342 k ,
339343 self ._knorm_bis ,
340344 self ._knorm ,
@@ -343,9 +347,7 @@ def _kValue_untrunc_interp_coeffs(self):
343347
344348 # slope to match the interpolant onto an asymptotic expansion of kv
345349 # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x)
346- aval = self ._kValue_untrunc_asymp_func (
347- self .beta , k [- 1 ], self ._knorm_bis , self ._r0
348- )
350+ aval = self ._kValue_untrunc_asymp_func (beta , k [- 1 ], self ._knorm_bis , self ._r0 )
349351 slp = (vals [- 1 ] / aval - 1 ) * k [- 1 ] * self ._r0
350352
351353 return k , vals , akima_interp_coeffs (k , vals ), slp
@@ -355,6 +357,9 @@ def _kValue(self, kpos):
355357 """computation of the Moffat response in k-space with interpolant + expansions
356358 kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
357359 """
360+ # we cannot compute gradients with respect to beta
361+ beta = jax .lax .stop_gradient (self .beta )
362+
358363 k = safe_sqrt (kpos .x ** 2 + kpos .y ** 2 )
359364 out_shape = jnp .shape (k )
360365 k = jnp .atleast_1d (k )
@@ -364,7 +369,12 @@ def _kValue(self, kpos):
364369 k_msk = jnp .where (k > 0 , k , k_ [1 ])
365370 res = jnp .where (
366371 k > k_ [- 1 ],
367- self ._kValue_untrunc_asymp_func (self .beta , k_msk , self ._knorm_bis , self ._r0 )
372+ self ._kValue_untrunc_asymp_func (
373+ beta ,
374+ k_msk ,
375+ self ._knorm_bis ,
376+ self ._r0 ,
377+ )
368378 * (1.0 + slp / k_msk / self ._r0 ),
369379 res ,
370380 )
0 commit comments