diff --git a/pyproject.toml b/pyproject.toml index fe5ffdd2..dfd454e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "astropy>=7.0", "beartype==0.21.0", # TODO: figure out why 0.22.0 breaks things - "coordinax>=0.23.1", + "coordinax>=0.23.2", "dataclassish>=0.8.0", "diffrax>=0.7", "diffraxtra>=1.5.2", @@ -51,7 +51,7 @@ "quaxed>=0.10.2", "tfp-nightly[jax]>=0.25.0", "typing-extensions>=4.13.2", - "unxt>=1.7.7", + "unxt>=1.10.3", "wadler_lindig>=0.1.7", "xmmutablemap>=0.2.1", "zeroth>=1.0", diff --git a/src/galax/_interop/galax_interop_astropy/potential.py b/src/galax/_interop/galax_interop_astropy/potential.py index 8cba62f5..13612002 100644 --- a/src/galax/_interop/galax_interop_astropy/potential.py +++ b/src/galax/_interop/galax_interop_astropy/potential.py @@ -47,7 +47,7 @@ def parse_to_xyz_t( # ============================================================================= -@dispatch(precedence=1) # type: ignore[call-overload,misc] +@dispatch(precedence=1) def potential( pot: gp.AbstractPotential, xyz: Real[APYQuantity, "*#batch 3"], @@ -63,7 +63,7 @@ def potential( return gp.potential(pot, convert(xyz, FastQ), convert(t, FastQ)) -@dispatch(precedence=1) # type: ignore[call-overload,misc] +@dispatch(precedence=1) def gradient( pot: gp.AbstractPotential, xyz: Real[APYQuantity, "*#batch 3"], @@ -79,7 +79,7 @@ def gradient( return gp.gradient(pot, convert(xyz, FastQ), convert(t, FastQ)) -@dispatch(precedence=1) # type: ignore[call-overload,misc] +@dispatch(precedence=1) def density( pot: gp.AbstractPotential, xyz: Real[APYQuantity, "*#batch 3"], @@ -95,7 +95,7 @@ def density( return gp.density(pot, convert(xyz, FastQ), convert(t, FastQ)) -@dispatch(precedence=1) # type: ignore[call-overload,misc] +@dispatch(precedence=1) def hessian( pot: gp.AbstractPotential, xyz: Real[APYQuantity, "*#batch 3"], diff --git a/src/galax/coordinates/_src/base.py b/src/galax/coordinates/_src/base.py index ddae89d3..f656e6a2 100644 --- a/src/galax/coordinates/_src/base.py +++ b/src/galax/coordinates/_src/base.py @@ -279,9 +279,7 @@ def __getitem__( def __pdoc__(self, **kwargs: object) -> wl.AbstractDoc: return wl.bracketed( begin=wl.TextDoc(f"{self.__class__.__name__}("), - docs=wl.named_objs( - field_items(self), short_arrays=False, compact_arrays=True - ), + docs=wl.named_objs(field_items(self), short_arrays="compact"), sep=wl.comma, end=wl.TextDoc(")"), indent=kwargs.get("indent", 4), diff --git a/src/galax/coordinates/_src/pscs/base_composite.py b/src/galax/coordinates/_src/pscs/base_composite.py index b84e1823..bbcd5042 100644 --- a/src/galax/coordinates/_src/pscs/base_composite.py +++ b/src/galax/coordinates/_src/pscs/base_composite.py @@ -255,7 +255,7 @@ def getitem(self: AbstractCompositePhaseSpaceCoordinate, key: str, /) -> Any: # `unxt.uconvert` -@dispatch(precedence=1) # type: ignore[call-overload,misc] # TODO: make precedence=0 +@dispatch(precedence=1) # TODO: make precedence=0 def uconvert( usys: u.AbstractUnitSystem | str, cwt: AbstractCompositePhaseSpaceCoordinate, / ) -> AbstractCompositePhaseSpaceCoordinate: diff --git a/src/galax/coordinates/_src/pscs/base_single.py b/src/galax/coordinates/_src/pscs/base_single.py index eeb7e6b9..d156ce61 100644 --- a/src/galax/coordinates/_src/pscs/base_single.py +++ b/src/galax/coordinates/_src/pscs/base_single.py @@ -63,7 +63,7 @@ class AbstractBasicPhaseSpaceCoordinate(AbstractPhaseSpaceCoordinate): # `unxt.uconvert` dispatches -@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0 +@dispatch(precedence=1) # TODO: make precedence=0 def uconvert( units: u.AbstractUnitSystem | str, wt: AbstractBasicPhaseSpaceCoordinate ) -> AbstractBasicPhaseSpaceCoordinate: diff --git a/src/galax/coordinates/_src/psps/core.py b/src/galax/coordinates/_src/psps/core.py index dc6add84..afcab21b 100644 --- a/src/galax/coordinates/_src/psps/core.py +++ b/src/galax/coordinates/_src/psps/core.py @@ -309,7 +309,7 @@ def vconvert( # `unxt.uconvert` dispatches -@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0 +@dispatch(precedence=1) # TODO: make precedence=0 def uconvert( units: u.AbstractUnitSystem | str, psp: PhaseSpacePosition ) -> PhaseSpacePosition: diff --git a/src/galax/dynamics/_src/parsetime.py b/src/galax/dynamics/_src/parsetime.py index 6cdfa821..47d5c34f 100644 --- a/src/galax/dynamics/_src/parsetime.py +++ b/src/galax/dynamics/_src/parsetime.py @@ -1,6 +1,6 @@ """Utilities for dynamics.""" -__all__ = ["parse_time_specification"] +__all__ = ("parse_time_specification",) from collections.abc import Sequence @@ -12,7 +12,6 @@ import quaxed.numpy as jnp import unxt as u -from unxt._src.units.api import AstropyUnits as Unit def parse_time_specification( @@ -163,7 +162,7 @@ def parse_time_specification( # ----------------------------------------------- -def _parse_to_time_unit(obj: Any, /) -> Unit | None: +def _parse_to_time_unit(obj: Any, /) -> u.AbstractUnit | None: return u.unitsystem(obj)["time"] if obj is not None else None diff --git a/src/galax/potential/_src/params/base.py b/src/galax/potential/_src/params/base.py index 8c388ca2..d6017a65 100644 --- a/src/galax/potential/_src/params/base.py +++ b/src/galax/potential/_src/params/base.py @@ -1,16 +1,13 @@ """Parameters on a Potential.""" -__all__ = [ - "ParameterCallable", - "AbstractParameter", -] +__all__ = ("ParameterCallable", "AbstractParameter") import abc from typing import Any, Protocol, runtime_checkable import equinox as eqx -from unxt._src.units.api import AstropyUnits +import unxt as u import galax._custom_types as gt @@ -20,7 +17,7 @@ class ParameterCallable(Protocol): """Protocol for a Parameter callable.""" def __call__( - self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any + self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any ) -> gt.QuSzAny | gt.SzAny: """Compute the parameter value at the given time(s). @@ -56,7 +53,7 @@ class AbstractParameter(eqx.Module): # type: ignore[misc] @abc.abstractmethod def __call__( - self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any + self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any ) -> gt.QuSzAny: """Compute the parameter value at the given time(s). diff --git a/src/galax/potential/_src/params/constant.py b/src/galax/potential/_src/params/constant.py index 630e2279..9c5da815 100644 --- a/src/galax/potential/_src/params/constant.py +++ b/src/galax/potential/_src/params/constant.py @@ -16,7 +16,6 @@ import unxt as u from dataclassish.converters import Unless -from unxt._src.units.api import AstropyUnits from unxt.quantity import AllowValue import galax._custom_types as gt @@ -128,7 +127,7 @@ def __call__( self, t: gt.BBtQuSz0 = t0, # noqa: ARG002 *, - ustrip: AstropyUnits | None = None, + ustrip: u.AbstractUnit | None = None, **__: Any, ) -> gt.QuSzAny: """Return the constant parameter value. diff --git a/src/galax/potential/_src/params/core.py b/src/galax/potential/_src/params/core.py index 8e80932b..710dcfba 100644 --- a/src/galax/potential/_src/params/core.py +++ b/src/galax/potential/_src/params/core.py @@ -13,7 +13,6 @@ import jax.core import unxt as u -from unxt._src.units.api import AstropyUnits from unxt.quantity import AllowValue import galax._custom_types as gt @@ -84,7 +83,7 @@ def __check_init__(self) -> None: @ft.partial(jax.jit, static_argnames=("ustrip",)) def __call__( - self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **_: Any + self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **_: Any ) -> gt.QuSzAny | gt.SzAny: """Return the parameter value. @@ -150,7 +149,7 @@ class CustomParameter(AbstractParameter): @ft.partial(jax.jit, static_argnames=("ustrip",)) def __call__( - self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any + self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any ) -> gt.QuSzAny | gt.SzAny: out = self.func(t, **kwargs) return out if ustrip is None else u.ustrip(AllowValue, ustrip, out) diff --git a/src/galax/potential/_src/xfm/translate.py b/src/galax/potential/_src/xfm/translate.py index 917d7411..24db76ea 100644 --- a/src/galax/potential/_src/xfm/translate.py +++ b/src/galax/potential/_src/xfm/translate.py @@ -14,7 +14,6 @@ from plum import dispatch import unxt as u -from unxt._src.units.api import AstropyUnits from unxt.quantity import AllowValue import galax._custom_types as gt @@ -172,7 +171,7 @@ def __call__( self, t: gt.BBtQorVSz0, *, - ustrip: AstropyUnits | None = None, + ustrip: u.AbstractUnit | None = None, **_: Any, ) -> gt.BBtQuSz3: t = u.ustrip(u.quantity.AllowValue, self.units["time"], t) diff --git a/uv.lock b/uv.lock index bb1fb86f..03f78721 100644 --- a/uv.lock +++ b/uv.lock @@ -463,7 +463,7 @@ wheels = [ [[package]] name = "coordinax" -version = "0.23.1" +version = "0.23.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astropy" }, @@ -483,9 +483,9 @@ dependencies = [ { name = "wadler-lindig" }, { name = "xmmutablemap" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7d/19/e916264a6560700b8cc8c422dcc8f148490ff1cc41cdbc2ea0c3205b9a75/coordinax-0.23.1.tar.gz", hash = "sha256:f6a730661872b804c6844a2a8f506af400d1f93d9d903d88938c8d7a629cc542", size = 735648, upload-time = "2025-10-22T20:24:41.689Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/53/4cf963f6d87b26fea3d4e1b1c8308fc28fc6e3466d4e68cfcd6d4516abd1/coordinax-0.23.2.tar.gz", hash = "sha256:8e44e9e148275a2bbd3da8470b37d0c4c441a02452d6b830c1cf3e173a39034e", size = 742870, upload-time = "2026-01-31T19:36:40.521Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/d3/58d01dd10c2b3512269c1014cc4c352ea561e331bff5304ab2b2e38abea5/coordinax-0.23.1-py3-none-any.whl", hash = "sha256:564e9d8735c02da81cf5ecac8b074f375ec61e1c895526398d4a3e1921f4e85a", size = 194514, upload-time = "2025-10-22T20:24:39.919Z" }, + { url = "https://files.pythonhosted.org/packages/0a/77/0c41522a7a9e18571e9532602ac2cb8524bca2b6e47d5b54ac8c449035f7/coordinax-0.23.2-py3-none-any.whl", hash = "sha256:402a0330f6f77b763b4729c8adef76e86cba2c154690e315afbdea62d9e9b126", size = 194695, upload-time = "2026-01-31T19:36:38.997Z" }, ] [[package]] @@ -1044,7 +1044,7 @@ requires-dist = [ { name = "astropy", marker = "extra == 'interop-gala'", specifier = ">=6.1" }, { name = "astropy", marker = "extra == 'interop-galpy'", specifier = ">=6.1" }, { name = "beartype", specifier = "==0.21.0" }, - { name = "coordinax", specifier = ">=0.23.1" }, + { name = "coordinax", specifier = ">=0.23.2" }, { name = "dataclassish", specifier = ">=0.8.0" }, { name = "diffrax", specifier = ">=0.7" }, { name = "diffraxtra", specifier = ">=1.5.2" }, @@ -1075,7 +1075,7 @@ requires-dist = [ { name = "quaxed", specifier = ">=0.10.2" }, { name = "tfp-nightly", extras = ["jax"], specifier = ">=0.25.0" }, { name = "typing-extensions", specifier = ">=4.13.2" }, - { name = "unxt", specifier = ">=1.7.7" }, + { name = "unxt", specifier = ">=1.10.3" }, { name = "wadler-lindig", specifier = ">=0.1.7" }, { name = "xmmutablemap", specifier = ">=0.2.1" }, { name = "zeroth", specifier = ">=1.0" }, @@ -2488,7 +2488,7 @@ wheels = [ [[package]] name = "quaxed" -version = "0.10.2" +version = "0.10.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "equinox" }, @@ -2499,9 +2499,9 @@ dependencies = [ { name = "plum-dispatch" }, { name = "quax" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/32/3d/79857c170ff6ddfb5fb590a563c93837aa1e2f41c5016ccbc85e8d5c7c7b/quaxed-0.10.2.tar.gz", hash = "sha256:7260533a5212a109bd61d29e7789c5d7bd1fd16a2a1ab0d2d821b47043945c2e", size = 134153, upload-time = "2025-10-19T15:12:51.663Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/10/1a8ad41af9e0cee70cedb6091c274a2c466e28df434248ed8adf5ff19198/quaxed-0.10.4.tar.gz", hash = "sha256:8f12df2fc938e39e2a9b809e660a287c6d26c46c700bfcbb5bc2fbd7f1741924", size = 139122, upload-time = "2025-12-05T22:42:30.032Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/09/9a54bd735dfefe7c02577345bd68502a9a8111c9cf41d6816c1dfcffb939/quaxed-0.10.2-py3-none-any.whl", hash = "sha256:b8e6453ada2947c66ae62a1df9e47bcda402961318b5719de7d16d7dfe957f8a", size = 28924, upload-time = "2025-10-19T15:12:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/f0/08/04cae4a2357439dbf0f98ceb60855557ed23739f565fbaccbc2e2c93e9ca/quaxed-0.10.4-py3-none-any.whl", hash = "sha256:5153a3aa96969df2c36b6eb31419b794020ecb104202106540020dec0c20b4aa", size = 33943, upload-time = "2025-12-05T22:42:28.559Z" }, ] [[package]] @@ -2921,7 +2921,7 @@ wheels = [ [[package]] name = "unxt" -version = "1.7.7" +version = "1.10.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astropy" }, @@ -2936,13 +2936,26 @@ dependencies = [ { name = "quax" }, { name = "quax-blocks" }, { name = "quaxed" }, + { name = "unxt-api" }, { name = "wadler-lindig" }, { name = "xmmutablemap" }, { name = "zeroth" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/97/ad2f92a851b7efc80e9aa07a9f1a4ffd4d8b43b6dd1c6e01110362e94e53/unxt-1.7.7.tar.gz", hash = "sha256:b41463a0768cc39a96dc5f56748a8e22cb8dc3aff5015bc21054cabc1e1a7e36", size = 886755, upload-time = "2025-11-17T02:50:57.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/25/2f/78436064bc5112882dcf871f5a51f16a6a2b2c27c7e36e8ca146e51cab34/unxt-1.10.3.tar.gz", hash = "sha256:522318b9d19171f52faeaceb5aff9ceaa73c39ba7f7dfea434b42c1d126bc164", size = 1005379, upload-time = "2026-01-31T17:28:11.699Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0f/81/9816966d2b203a9e8871103aeaea3c3d6b7227a4d8af90c8f354ea27d0b5/unxt-1.7.7-py3-none-any.whl", hash = "sha256:1c5ccf57f4bfd3474e0546733932301d92f26ba65878d85f6345c5541c359943", size = 74432, upload-time = "2025-11-17T02:50:56.014Z" }, + { url = "https://files.pythonhosted.org/packages/15/e7/2635b0f644deb4cf9e92b3d3c10f6230ecde9814b85808b4302c8d7cf7fd/unxt-1.10.3-py3-none-any.whl", hash = "sha256:127c6a66e08126bba40e9f9383af31c089a9112d06eebabdf9eb1cdeff0f1359", size = 83510, upload-time = "2026-01-31T17:28:10.261Z" }, +] + +[[package]] +name = "unxt-api" +version = "1.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "plum-dispatch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/59/55/a9e4f226a3d2f26a8bafc55868b8b8ec290cdf3766d9dc378b3b2dd0eb0c/unxt_api-1.10.1.tar.gz", hash = "sha256:dae10d19edc0b96b225dfa122d0ed5999ad8523d3b0e63f33a3acc2ccdff0cbc", size = 21341, upload-time = "2026-01-24T22:28:06.289Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/64/4f4fb51ee7558784a566fb30d8e0d04c689501f7856aa47d279bf6d745cf/unxt_api-1.10.1-py3-none-any.whl", hash = "sha256:7cc4ed4b4d22a9a93fe185add206861a5931f45e892d68d86c8034c445e28ba6", size = 6266, upload-time = "2026-01-24T22:28:05.087Z" }, ] [[package]]