Skip to content

Commit 368202a

Browse files
committed
✨ feat(dynamics): stream simulator
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent 7d27b80 commit 368202a

8 files changed

Lines changed: 676 additions & 4 deletions

File tree

src/galax/dynamics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
"AbstractSolver",
1616
"DynamicsSolver",
1717
# mockstream
18+
"StreamSimulator",
1819
"MockStreamArm",
1920
"MockStream",
20-
"MockStreamGenerator",
21+
"MockStreamGenerator", # TODO: deprecate
2122
# mockstream.df
2223
"AbstractStreamDF",
2324
"FardalStreamDF",
@@ -47,6 +48,7 @@
4748
MockStream,
4849
MockStreamArm,
4950
MockStreamGenerator,
51+
StreamSimulator,
5052
)
5153
from .solve import AbstractSolver, DynamicsSolver
5254

src/galax/dynamics/_src/mockstream/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,19 @@
55
"""
66

77
__all__ = [
8+
"StreamSimulator",
9+
# Coordinates
810
"MockStream",
911
"MockStreamArm",
12+
# Phase-Space Distribution
13+
"AbstractStreamDF",
14+
"Fardal15StreamDF",
15+
"Chen24StreamDF",
1016
]
1117

1218
from .arm import MockStreamArm
1319
from .core import MockStream
20+
from .df_base import AbstractStreamDF
21+
from .df_chen24 import Chen24StreamDF
22+
from .df_fardal15 import Fardal15StreamDF
23+
from .simulate import StreamSimulator

src/galax/dynamics/_src/mockstream/arm.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from typing import Any, ClassVar, Protocol, cast, final, runtime_checkable
66

7+
import diffrax as dfx
78
import equinox as eqx
89
from plum import dispatch
910

1011
import coordinax as cx
1112
import quaxed.numpy as jnp
1213
import unxt as u
14+
from unxt.quantity import BareQuantity
1315

1416
import galax._custom_types as gt
1517
import galax.coordinates as gc
@@ -67,6 +69,41 @@ def _shape_tuple(self) -> tuple[gt.Shape, gc.ComponentShapeTuple]:
6769

6870
#####################################################################
6971

72+
73+
@gc.AbstractPhaseSpaceObject.from_.dispatch # type: ignore[attr-defined,misc]
74+
def from_(
75+
cls: type[MockStreamArm],
76+
soln: dfx.Solution,
77+
/,
78+
*,
79+
release_time: gt.BBtQuSz0,
80+
frame: cx.frames.AbstractReferenceFrame,
81+
units: u.AbstractUnitSystem, # not dispatched on, but required
82+
unbatch_time: bool = True,
83+
) -> MockStreamArm:
84+
"""Create a new instance of the class."""
85+
# Reshape (*tbatch, T, *ybatch) to (*tbatch, *ybatch, T)
86+
t = soln.ts # already in the shape (*tbatch, T)
87+
n_tbatch = soln.t0.ndim
88+
q = jnp.moveaxis(soln.ys[0], n_tbatch, -2)
89+
p = jnp.moveaxis(soln.ys[1], n_tbatch, -2)
90+
91+
# Reshape (*tbatch, *ybatch, T) to (*tbatch, *ybatch) if T == 1
92+
if unbatch_time and t.shape[-1] == 1:
93+
t = t[..., -1]
94+
q = q[..., -1, :]
95+
p = p[..., -1, :]
96+
97+
# Convert the solution to a phase-space position
98+
return cls(
99+
q=cx.CartesianPos3D.from_(q, units["length"]),
100+
p=cx.CartesianVel3D.from_(p, units["speed"]),
101+
t=BareQuantity(t, units["time"]),
102+
release_time=release_time,
103+
frame=frame,
104+
)
105+
106+
70107
# =========================================================
71108
# `__getitem__`
72109

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Stream Distribution Functions for ejecting mock stream particles."""
2+
3+
__all__ = ["AbstractStreamDF"]
4+
5+
import abc
6+
from typing import TypeAlias
7+
8+
import equinox as eqx
9+
from jaxtyping import PRNGKeyArray
10+
11+
import galax._custom_types as gt
12+
import galax.potential as gp
13+
14+
Carry: TypeAlias = tuple[gt.QuSz3, gt.QuSz3, gt.QuSz3, gt.QuSz3]
15+
16+
17+
class AbstractStreamDF(eqx.Module, strict=True): # type: ignore[call-arg, misc]
18+
"""Abstract base class of Stream Distribution Functions."""
19+
20+
# TODO: keep units and PSP through this func
21+
@abc.abstractmethod
22+
def sample(
23+
self,
24+
key: PRNGKeyArray,
25+
potential: gp.AbstractPotential,
26+
x: gt.BBtQuSz3,
27+
v: gt.BBtQuSz3,
28+
prog_mass: gt.BBtFloatQuSz0,
29+
t: gt.BBtFloatQuSz0,
30+
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
31+
"""Generate stream particle initial conditions.
32+
33+
Parameters
34+
----------
35+
rng : :class:`jaxtyping.PRNGKeyArray`
36+
Pseudo-random number generator.
37+
potential : :class:`galax.potential.AbstractPotential`
38+
The potential of the host galaxy.
39+
x : Quantity[float, (*#batch, 3), "length"]
40+
3d position (x, y, z)
41+
v : Quantity[float, (*#batch, 3), "speed"]
42+
3d velocity (v_x, v_y, v_z)
43+
prog_mass : Quantity[float, (*#batch), "mass"]
44+
Mass of the progenitor.
45+
t : Quantity[float, (*#batch), "time"]
46+
The release time of the stream particles.
47+
48+
Returns
49+
-------
50+
x_lead, v_lead: Quantity[float, (*batch, 3), "length" | "speed"]
51+
Position and velocity of the leading arm.
52+
x_trail, v_trail : Quantity[float, (*batch, 3), "length" | "speed"]
53+
Position and velocity of the trailing arm.
54+
"""
55+
...
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""galax: Galactic Dynamix in Jax."""
2+
3+
__all__ = ["Chen24StreamDF"]
4+
5+
6+
import warnings
7+
from functools import partial
8+
from typing import final
9+
10+
import jax
11+
import jax.random as jr
12+
from jaxtyping import PRNGKeyArray
13+
14+
import coordinax as cx
15+
import quaxed.numpy as jnp
16+
17+
import galax._custom_types as gt
18+
import galax.potential as gp
19+
from .df_base import AbstractStreamDF
20+
from galax.dynamics._src.cluster.radius import tidal_radius
21+
from galax.dynamics._src.register_api import specific_angular_momentum
22+
23+
# ============================================================
24+
# Constants
25+
26+
mean = jnp.array([1.6, -30, 0, 1, 20, 0])
27+
28+
cov = jnp.array(
29+
[
30+
[0.1225, 0, 0, 0, -4.9, 0],
31+
[0, 529, 0, 0, 0, 0],
32+
[0, 0, 144, 0, 0, 0],
33+
[0, 0, 0, 0, 0, 0],
34+
[-4.9, 0, 0, 0, 400, 0],
35+
[0, 0, 0, 0, 0, 484],
36+
]
37+
)
38+
39+
# ============================================================
40+
41+
42+
@final
43+
class Chen24StreamDF(AbstractStreamDF):
44+
"""Chen Stream Distribution Function.
45+
46+
A class for representing the Chen+2024 distribution function for
47+
generating stellar streams based on Chen et al. 2024
48+
https://ui.adsabs.harvard.edu/abs/2024arXiv240801496C/abstract
49+
"""
50+
51+
def __init__(self) -> None:
52+
super().__init__()
53+
warnings.warn(
54+
'Currently only the "no progenitor" version '
55+
"of the Chen+24 model is supported!",
56+
RuntimeWarning,
57+
stacklevel=1,
58+
)
59+
60+
@partial(jax.jit, inline=True)
61+
def sample(
62+
self,
63+
key: PRNGKeyArray,
64+
potential: gp.AbstractPotential,
65+
x: gt.BBtQuSz3,
66+
v: gt.BBtQuSz3,
67+
prog_mass: gt.BBtFloatQuSz0,
68+
t: gt.BBtFloatQuSz0,
69+
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
70+
"""Generate stream particle initial conditions."""
71+
# Random number generation
72+
73+
# x_new-hat
74+
r = jnp.linalg.vector_norm(x, axis=-1, keepdims=True)
75+
x_new_hat = x / r
76+
77+
# z_new-hat
78+
L_vec = specific_angular_momentum(x, v)
79+
z_new_hat = cx.vecs.normalize_vector(L_vec)
80+
81+
# y_new-hat
82+
phi_vec = v - jnp.sum(v * x_new_hat, axis=-1, keepdims=True) * x_new_hat
83+
y_new_hat = cx.vecs.normalize_vector(phi_vec)
84+
85+
r_tidal = tidal_radius(potential, x, v, mass=prog_mass, t=t)
86+
87+
# Bill Chen: method="cholesky" doesn't work here!
88+
posvel = jr.multivariate_normal(
89+
key, mean, cov, shape=r_tidal.shape, method="svd"
90+
)
91+
92+
Dr = posvel[:, 0] * r_tidal
93+
94+
v_esc = jnp.sqrt(2 * potential.constants["G"] * prog_mass / Dr)
95+
Dv = posvel[:, 3] * v_esc
96+
97+
# convert degrees to radians
98+
phi = posvel[:, 1] * 0.017453292519943295
99+
theta = posvel[:, 2] * 0.017453292519943295
100+
alpha = posvel[:, 4] * 0.017453292519943295
101+
beta = posvel[:, 5] * 0.017453292519943295
102+
103+
ctheta, stheta = jnp.cos(theta), jnp.sin(theta)
104+
cphi, sphi = jnp.cos(phi), jnp.sin(phi)
105+
calpha, salpha = jnp.cos(alpha), jnp.sin(alpha)
106+
cbeta, sbeta = jnp.cos(beta), jnp.sin(beta)
107+
108+
# Trailing arm
109+
x_trail = (
110+
x
111+
+ (Dr * ctheta * cphi)[:, None] * x_new_hat
112+
+ (Dr * ctheta * sphi)[:, None] * y_new_hat
113+
+ (Dr * stheta)[:, None] * z_new_hat
114+
)
115+
v_trail = (
116+
v
117+
+ (Dv * cbeta * calpha)[:, None] * x_new_hat
118+
+ (Dv * cbeta * salpha)[:, None] * y_new_hat
119+
+ (Dv * sbeta)[:, None] * z_new_hat
120+
)
121+
122+
# Leading arm
123+
x_lead = (
124+
x
125+
- (Dr * ctheta * cphi)[:, None] * x_new_hat
126+
- (Dr * ctheta * sphi)[:, None] * y_new_hat
127+
+ (Dr * stheta)[:, None] * z_new_hat
128+
)
129+
v_lead = (
130+
v
131+
- (Dv * cbeta * calpha)[:, None] * x_new_hat
132+
- (Dv * cbeta * salpha)[:, None] * y_new_hat
133+
+ (Dv * sbeta)[:, None] * z_new_hat
134+
)
135+
136+
return x_lead, v_lead, x_trail, v_trail
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""galax: Galactic Dynamix in Jax."""
2+
3+
__all__ = ["Fardal15StreamDF"]
4+
5+
6+
from functools import partial
7+
from typing import final
8+
9+
import jax
10+
import jax.random as jr
11+
from jaxtyping import PRNGKeyArray
12+
13+
import coordinax as cx
14+
import quaxed.numpy as jnp
15+
16+
import galax._custom_types as gt
17+
import galax.potential as gp
18+
from .df_base import AbstractStreamDF
19+
from galax.dynamics._src.api import omega
20+
from galax.dynamics._src.cluster.radius import tidal_radius
21+
22+
# ============================================================
23+
# Constants
24+
25+
kr_bar = 2.0
26+
kvphi_bar = 0.3
27+
28+
kz_bar = 0.0
29+
kvz_bar = 0.0
30+
31+
sigma_kr = 0.5 # TODO: use actual Fardal values
32+
sigma_kvphi = 0.5 # TODO: use actual Fardal values
33+
sigma_kz = 0.5
34+
sigma_kvz = 0.5
35+
36+
# ============================================================
37+
38+
39+
@final
40+
class Fardal15StreamDF(AbstractStreamDF):
41+
"""Fardal Stream Distribution Function.
42+
43+
A class for representing the Fardal+2015 distribution function for
44+
generating stellar streams based on Fardal et al. 2015
45+
https://ui.adsabs.harvard.edu/abs/2015MNRAS.452..301F/abstract
46+
"""
47+
48+
@partial(jax.jit)
49+
def sample(
50+
self,
51+
key: PRNGKeyArray,
52+
potential: gp.AbstractPotential,
53+
x: gt.BBtQuSz3,
54+
v: gt.BBtQuSz3,
55+
prog_mass: gt.BBtFloatQuSz0,
56+
t: gt.BBtFloatQuSz0,
57+
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
58+
"""Generate stream particle initial conditions."""
59+
# Random number generation
60+
key1, key2, key3, key4 = jr.split(key, 4)
61+
62+
om = omega(x, v)[..., None]
63+
64+
# r-hat
65+
r_hat = cx.vecs.normalize_vector(x)
66+
67+
r_tidal = tidal_radius(potential, x, v, mass=prog_mass, t=t)[..., None]
68+
v_circ = om * r_tidal # relative velocity
69+
70+
# z-hat
71+
L_vec = jnp.linalg.cross(x, v)
72+
z_hat = cx.vecs.normalize_vector(L_vec)
73+
74+
# phi-hat
75+
phi_vec = v - jnp.sum(v * r_hat, axis=-1, keepdims=True) * r_hat
76+
phi_hat = cx.vecs.normalize_vector(phi_vec)
77+
78+
# k vals
79+
shape = r_tidal.shape
80+
kr_samp = kr_bar + jr.normal(key1, shape) * sigma_kr
81+
kvphi_samp = kr_samp * (kvphi_bar + jr.normal(key2, shape) * sigma_kvphi)
82+
kz_samp = kz_bar + jr.normal(key3, shape) * sigma_kz
83+
kvz_samp = kvz_bar + jr.normal(key4, shape) * sigma_kvz
84+
85+
# Trailing arm
86+
x_trail = x + r_tidal * (kr_samp * r_hat + kz_samp * z_hat)
87+
v_trail = v + v_circ * (kvphi_samp * phi_hat + kvz_samp * z_hat)
88+
89+
# Leading arm
90+
x_lead = x - r_tidal * (kr_samp * r_hat - kz_samp * z_hat)
91+
v_lead = v - v_circ * (kvphi_samp * phi_hat - kvz_samp * z_hat)
92+
93+
return x_lead, v_lead, x_trail, v_trail

0 commit comments

Comments
 (0)