Skip to content

Commit 2c6f83f

Browse files
authored
⬆️ dep-bump: unxt, coordinax (#784)
1 parent 86ae174 commit 2c6f83f

12 files changed

Lines changed: 44 additions & 40 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
dependencies = [
3030
"astropy>=7.0",
3131
"beartype==0.21.0", # TODO: figure out why 0.22.0 breaks things
32-
"coordinax>=0.23.1",
32+
"coordinax>=0.23.2",
3333
"dataclassish>=0.8.0",
3434
"diffrax>=0.7",
3535
"diffraxtra>=1.5.2",
@@ -51,7 +51,7 @@
5151
"quaxed>=0.10.2",
5252
"tfp-nightly[jax]>=0.25.0",
5353
"typing-extensions>=4.13.2",
54-
"unxt>=1.7.7",
54+
"unxt>=1.10.3",
5555
"wadler_lindig>=0.1.7",
5656
"xmmutablemap>=0.2.1",
5757
"zeroth>=1.0",

src/galax/_interop/galax_interop_astropy/potential.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def parse_to_xyz_t(
4747
# =============================================================================
4848

4949

50-
@dispatch(precedence=1) # type: ignore[call-overload,misc]
50+
@dispatch(precedence=1)
5151
def potential(
5252
pot: gp.AbstractPotential,
5353
xyz: Real[APYQuantity, "*#batch 3"],
@@ -63,7 +63,7 @@ def potential(
6363
return gp.potential(pot, convert(xyz, FastQ), convert(t, FastQ))
6464

6565

66-
@dispatch(precedence=1) # type: ignore[call-overload,misc]
66+
@dispatch(precedence=1)
6767
def gradient(
6868
pot: gp.AbstractPotential,
6969
xyz: Real[APYQuantity, "*#batch 3"],
@@ -79,7 +79,7 @@ def gradient(
7979
return gp.gradient(pot, convert(xyz, FastQ), convert(t, FastQ))
8080

8181

82-
@dispatch(precedence=1) # type: ignore[call-overload,misc]
82+
@dispatch(precedence=1)
8383
def density(
8484
pot: gp.AbstractPotential,
8585
xyz: Real[APYQuantity, "*#batch 3"],
@@ -95,7 +95,7 @@ def density(
9595
return gp.density(pot, convert(xyz, FastQ), convert(t, FastQ))
9696

9797

98-
@dispatch(precedence=1) # type: ignore[call-overload,misc]
98+
@dispatch(precedence=1)
9999
def hessian(
100100
pot: gp.AbstractPotential,
101101
xyz: Real[APYQuantity, "*#batch 3"],

src/galax/coordinates/_src/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,7 @@ def __getitem__(
279279
def __pdoc__(self, **kwargs: object) -> wl.AbstractDoc:
280280
return wl.bracketed(
281281
begin=wl.TextDoc(f"{self.__class__.__name__}("),
282-
docs=wl.named_objs(
283-
field_items(self), short_arrays=False, compact_arrays=True
284-
),
282+
docs=wl.named_objs(field_items(self), short_arrays="compact"),
285283
sep=wl.comma,
286284
end=wl.TextDoc(")"),
287285
indent=kwargs.get("indent", 4),

src/galax/coordinates/_src/pscs/base_composite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def getitem(self: AbstractCompositePhaseSpaceCoordinate, key: str, /) -> Any:
255255
# `unxt.uconvert`
256256

257257

258-
@dispatch(precedence=1) # type: ignore[call-overload,misc] # TODO: make precedence=0
258+
@dispatch(precedence=1) # TODO: make precedence=0
259259
def uconvert(
260260
usys: u.AbstractUnitSystem | str, cwt: AbstractCompositePhaseSpaceCoordinate, /
261261
) -> AbstractCompositePhaseSpaceCoordinate:

src/galax/coordinates/_src/pscs/base_single.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class AbstractBasicPhaseSpaceCoordinate(AbstractPhaseSpaceCoordinate):
6363
# `unxt.uconvert` dispatches
6464

6565

66-
@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0
66+
@dispatch(precedence=1) # TODO: make precedence=0
6767
def uconvert(
6868
units: u.AbstractUnitSystem | str, wt: AbstractBasicPhaseSpaceCoordinate
6969
) -> AbstractBasicPhaseSpaceCoordinate:

src/galax/coordinates/_src/psps/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def vconvert(
309309
# `unxt.uconvert` dispatches
310310

311311

312-
@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0
312+
@dispatch(precedence=1) # TODO: make precedence=0
313313
def uconvert(
314314
units: u.AbstractUnitSystem | str, psp: PhaseSpacePosition
315315
) -> PhaseSpacePosition:

src/galax/dynamics/_src/parsetime.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Utilities for dynamics."""
22

3-
__all__ = ["parse_time_specification"]
3+
__all__ = ("parse_time_specification",)
44

55

66
from collections.abc import Sequence
@@ -12,7 +12,6 @@
1212

1313
import quaxed.numpy as jnp
1414
import unxt as u
15-
from unxt._src.units.api import AstropyUnits as Unit
1615

1716

1817
def parse_time_specification(
@@ -163,7 +162,7 @@ def parse_time_specification(
163162
# -----------------------------------------------
164163

165164

166-
def _parse_to_time_unit(obj: Any, /) -> Unit | None:
165+
def _parse_to_time_unit(obj: Any, /) -> u.AbstractUnit | None:
167166
return u.unitsystem(obj)["time"] if obj is not None else None
168167

169168

src/galax/potential/_src/params/base.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
"""Parameters on a Potential."""
22

3-
__all__ = [
4-
"ParameterCallable",
5-
"AbstractParameter",
6-
]
3+
__all__ = ("ParameterCallable", "AbstractParameter")
74

85
import abc
96
from typing import Any, Protocol, runtime_checkable
107

118
import equinox as eqx
129

13-
from unxt._src.units.api import AstropyUnits
10+
import unxt as u
1411

1512
import galax._custom_types as gt
1613

@@ -20,7 +17,7 @@ class ParameterCallable(Protocol):
2017
"""Protocol for a Parameter callable."""
2118

2219
def __call__(
23-
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
20+
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
2421
) -> gt.QuSzAny | gt.SzAny:
2522
"""Compute the parameter value at the given time(s).
2623
@@ -56,7 +53,7 @@ class AbstractParameter(eqx.Module): # type: ignore[misc]
5653

5754
@abc.abstractmethod
5855
def __call__(
59-
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
56+
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
6057
) -> gt.QuSzAny:
6158
"""Compute the parameter value at the given time(s).
6259

src/galax/potential/_src/params/constant.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import unxt as u
1818
from dataclassish.converters import Unless
19-
from unxt._src.units.api import AstropyUnits
2019
from unxt.quantity import AllowValue
2120

2221
import galax._custom_types as gt
@@ -128,7 +127,7 @@ def __call__(
128127
self,
129128
t: gt.BBtQuSz0 = t0, # noqa: ARG002
130129
*,
131-
ustrip: AstropyUnits | None = None,
130+
ustrip: u.AbstractUnit | None = None,
132131
**__: Any,
133132
) -> gt.QuSzAny:
134133
"""Return the constant parameter value.

src/galax/potential/_src/params/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import jax.core
1414

1515
import unxt as u
16-
from unxt._src.units.api import AstropyUnits
1716
from unxt.quantity import AllowValue
1817

1918
import galax._custom_types as gt
@@ -84,7 +83,7 @@ def __check_init__(self) -> None:
8483

8584
@ft.partial(jax.jit, static_argnames=("ustrip",))
8685
def __call__(
87-
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **_: Any
86+
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **_: Any
8887
) -> gt.QuSzAny | gt.SzAny:
8988
"""Return the parameter value.
9089
@@ -150,7 +149,7 @@ class CustomParameter(AbstractParameter):
150149

151150
@ft.partial(jax.jit, static_argnames=("ustrip",))
152151
def __call__(
153-
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
152+
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
154153
) -> gt.QuSzAny | gt.SzAny:
155154
out = self.func(t, **kwargs)
156155
return out if ustrip is None else u.ustrip(AllowValue, ustrip, out)

0 commit comments

Comments
 (0)