|
| 1 | +"""galax: Galactic Dynamix in Jax.""" |
| 2 | + |
| 3 | +__all__ = ["Chen24StreamDF"] |
| 4 | + |
| 5 | + |
| 6 | +import warnings |
| 7 | +from functools import partial |
| 8 | +from typing import final |
| 9 | + |
| 10 | +import jax |
| 11 | +import jax.random as jr |
| 12 | +from jaxtyping import PRNGKeyArray |
| 13 | + |
| 14 | +import coordinax as cx |
| 15 | +import quaxed.numpy as jnp |
| 16 | + |
| 17 | +import galax._custom_types as gt |
| 18 | +import galax.potential as gp |
| 19 | +from .df_base import AbstractStreamDF |
| 20 | +from galax.dynamics._src.cluster.radius import tidal_radius |
| 21 | +from galax.dynamics._src.register_api import specific_angular_momentum |
| 22 | + |
| 23 | +# ============================================================ |
| 24 | +# Constants |
| 25 | + |
| 26 | +mean = jnp.array([1.6, -30, 0, 1, 20, 0]) |
| 27 | + |
| 28 | +cov = jnp.array( |
| 29 | + [ |
| 30 | + [0.1225, 0, 0, 0, -4.9, 0], |
| 31 | + [0, 529, 0, 0, 0, 0], |
| 32 | + [0, 0, 144, 0, 0, 0], |
| 33 | + [0, 0, 0, 0, 0, 0], |
| 34 | + [-4.9, 0, 0, 0, 400, 0], |
| 35 | + [0, 0, 0, 0, 0, 484], |
| 36 | + ] |
| 37 | +) |
| 38 | + |
| 39 | +# ============================================================ |
| 40 | + |
| 41 | + |
| 42 | +@final |
| 43 | +class Chen24StreamDF(AbstractStreamDF): |
| 44 | + """Chen Stream Distribution Function. |
| 45 | +
|
| 46 | + A class for representing the Chen+2024 distribution function for |
| 47 | + generating stellar streams based on Chen et al. 2024 |
| 48 | + https://ui.adsabs.harvard.edu/abs/2024arXiv240801496C/abstract |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self) -> None: |
| 52 | + super().__init__() |
| 53 | + warnings.warn( |
| 54 | + 'Currently only the "no progenitor" version ' |
| 55 | + "of the Chen+24 model is supported!", |
| 56 | + RuntimeWarning, |
| 57 | + stacklevel=1, |
| 58 | + ) |
| 59 | + |
| 60 | + @partial(jax.jit, inline=True) |
| 61 | + def sample( |
| 62 | + self, |
| 63 | + key: PRNGKeyArray, |
| 64 | + potential: gp.AbstractPotential, |
| 65 | + x: gt.BBtQuSz3, |
| 66 | + v: gt.BBtQuSz3, |
| 67 | + prog_mass: gt.BBtFloatQuSz0, |
| 68 | + t: gt.BBtFloatQuSz0, |
| 69 | + ) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]: |
| 70 | + """Generate stream particle initial conditions.""" |
| 71 | + # Random number generation |
| 72 | + |
| 73 | + # x_new-hat |
| 74 | + r = jnp.linalg.vector_norm(x, axis=-1, keepdims=True) |
| 75 | + x_new_hat = x / r |
| 76 | + |
| 77 | + # z_new-hat |
| 78 | + L_vec = specific_angular_momentum(x, v) |
| 79 | + z_new_hat = cx.vecs.normalize_vector(L_vec) |
| 80 | + |
| 81 | + # y_new-hat |
| 82 | + phi_vec = v - jnp.sum(v * x_new_hat, axis=-1, keepdims=True) * x_new_hat |
| 83 | + y_new_hat = cx.vecs.normalize_vector(phi_vec) |
| 84 | + |
| 85 | + r_tidal = tidal_radius(potential, x, v, mass=prog_mass, t=t) |
| 86 | + |
| 87 | + # Bill Chen: method="cholesky" doesn't work here! |
| 88 | + posvel = jr.multivariate_normal( |
| 89 | + key, mean, cov, shape=r_tidal.shape, method="svd" |
| 90 | + ) |
| 91 | + |
| 92 | + Dr = posvel[:, 0] * r_tidal |
| 93 | + |
| 94 | + v_esc = jnp.sqrt(2 * potential.constants["G"] * prog_mass / Dr) |
| 95 | + Dv = posvel[:, 3] * v_esc |
| 96 | + |
| 97 | + # convert degrees to radians |
| 98 | + phi = posvel[:, 1] * 0.017453292519943295 |
| 99 | + theta = posvel[:, 2] * 0.017453292519943295 |
| 100 | + alpha = posvel[:, 4] * 0.017453292519943295 |
| 101 | + beta = posvel[:, 5] * 0.017453292519943295 |
| 102 | + |
| 103 | + ctheta, stheta = jnp.cos(theta), jnp.sin(theta) |
| 104 | + cphi, sphi = jnp.cos(phi), jnp.sin(phi) |
| 105 | + calpha, salpha = jnp.cos(alpha), jnp.sin(alpha) |
| 106 | + cbeta, sbeta = jnp.cos(beta), jnp.sin(beta) |
| 107 | + |
| 108 | + # Trailing arm |
| 109 | + x_trail = ( |
| 110 | + x |
| 111 | + + (Dr * ctheta * cphi)[:, None] * x_new_hat |
| 112 | + + (Dr * ctheta * sphi)[:, None] * y_new_hat |
| 113 | + + (Dr * stheta)[:, None] * z_new_hat |
| 114 | + ) |
| 115 | + v_trail = ( |
| 116 | + v |
| 117 | + + (Dv * cbeta * calpha)[:, None] * x_new_hat |
| 118 | + + (Dv * cbeta * salpha)[:, None] * y_new_hat |
| 119 | + + (Dv * sbeta)[:, None] * z_new_hat |
| 120 | + ) |
| 121 | + |
| 122 | + # Leading arm |
| 123 | + x_lead = ( |
| 124 | + x |
| 125 | + - (Dr * ctheta * cphi)[:, None] * x_new_hat |
| 126 | + - (Dr * ctheta * sphi)[:, None] * y_new_hat |
| 127 | + + (Dr * stheta)[:, None] * z_new_hat |
| 128 | + ) |
| 129 | + v_lead = ( |
| 130 | + v |
| 131 | + - (Dv * cbeta * calpha)[:, None] * x_new_hat |
| 132 | + - (Dv * cbeta * salpha)[:, None] * y_new_hat |
| 133 | + + (Dv * sbeta)[:, None] * z_new_hat |
| 134 | + ) |
| 135 | + |
| 136 | + return x_lead, v_lead, x_trail, v_trail |
0 commit comments