Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
dependencies = [
"astropy>=7.0",
"beartype==0.21.0", # TODO: figure out why 0.22.0 breaks things
"coordinax>=0.23.1",
"coordinax>=0.23.2",
"dataclassish>=0.8.0",
"diffrax>=0.7",
"diffraxtra>=1.5.2",
Expand All @@ -51,7 +51,7 @@
"quaxed>=0.10.2",
"tfp-nightly[jax]>=0.25.0",
"typing-extensions>=4.13.2",
"unxt>=1.7.7",
"unxt>=1.10.3",
"wadler_lindig>=0.1.7",
"xmmutablemap>=0.2.1",
"zeroth>=1.0",
Expand Down
8 changes: 4 additions & 4 deletions src/galax/_interop/galax_interop_astropy/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def parse_to_xyz_t(
# =============================================================================


@dispatch(precedence=1) # type: ignore[call-overload,misc]
@dispatch(precedence=1)
def potential(
pot: gp.AbstractPotential,
xyz: Real[APYQuantity, "*#batch 3"],
Expand All @@ -63,7 +63,7 @@ def potential(
return gp.potential(pot, convert(xyz, FastQ), convert(t, FastQ))


@dispatch(precedence=1) # type: ignore[call-overload,misc]
@dispatch(precedence=1)
def gradient(
pot: gp.AbstractPotential,
xyz: Real[APYQuantity, "*#batch 3"],
Expand All @@ -79,7 +79,7 @@ def gradient(
return gp.gradient(pot, convert(xyz, FastQ), convert(t, FastQ))


@dispatch(precedence=1) # type: ignore[call-overload,misc]
@dispatch(precedence=1)
def density(
pot: gp.AbstractPotential,
xyz: Real[APYQuantity, "*#batch 3"],
Expand All @@ -95,7 +95,7 @@ def density(
return gp.density(pot, convert(xyz, FastQ), convert(t, FastQ))


@dispatch(precedence=1) # type: ignore[call-overload,misc]
@dispatch(precedence=1)
def hessian(
pot: gp.AbstractPotential,
xyz: Real[APYQuantity, "*#batch 3"],
Expand Down
4 changes: 1 addition & 3 deletions src/galax/coordinates/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,7 @@ def __getitem__(
def __pdoc__(self, **kwargs: object) -> wl.AbstractDoc:
return wl.bracketed(
begin=wl.TextDoc(f"{self.__class__.__name__}("),
docs=wl.named_objs(
field_items(self), short_arrays=False, compact_arrays=True
),
docs=wl.named_objs(field_items(self), short_arrays="compact"),
sep=wl.comma,
end=wl.TextDoc(")"),
indent=kwargs.get("indent", 4),
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_src/pscs/base_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def getitem(self: AbstractCompositePhaseSpaceCoordinate, key: str, /) -> Any:
# `unxt.uconvert`


@dispatch(precedence=1) # type: ignore[call-overload,misc] # TODO: make precedence=0
@dispatch(precedence=1) # TODO: make precedence=0
def uconvert(
usys: u.AbstractUnitSystem | str, cwt: AbstractCompositePhaseSpaceCoordinate, /
) -> AbstractCompositePhaseSpaceCoordinate:
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_src/pscs/base_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class AbstractBasicPhaseSpaceCoordinate(AbstractPhaseSpaceCoordinate):
# `unxt.uconvert` dispatches


@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0
@dispatch(precedence=1) # TODO: make precedence=0
def uconvert(
units: u.AbstractUnitSystem | str, wt: AbstractBasicPhaseSpaceCoordinate
) -> AbstractBasicPhaseSpaceCoordinate:
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_src/psps/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def vconvert(
# `unxt.uconvert` dispatches


@dispatch(precedence=1) # type: ignore[call-overload, misc] # TODO: make precedence=0
@dispatch(precedence=1) # TODO: make precedence=0
def uconvert(
units: u.AbstractUnitSystem | str, psp: PhaseSpacePosition
) -> PhaseSpacePosition:
Expand Down
5 changes: 2 additions & 3 deletions src/galax/dynamics/_src/parsetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for dynamics."""

__all__ = ["parse_time_specification"]
__all__ = ("parse_time_specification",)


from collections.abc import Sequence
Expand All @@ -12,7 +12,6 @@

import quaxed.numpy as jnp
import unxt as u
from unxt._src.units.api import AstropyUnits as Unit


def parse_time_specification(
Expand Down Expand Up @@ -163,7 +162,7 @@ def parse_time_specification(
# -----------------------------------------------


def _parse_to_time_unit(obj: Any, /) -> Unit | None:
def _parse_to_time_unit(obj: Any, /) -> u.AbstractUnit | None:
return u.unitsystem(obj)["time"] if obj is not None else None


Expand Down
11 changes: 4 additions & 7 deletions src/galax/potential/_src/params/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
"""Parameters on a Potential."""

__all__ = [
"ParameterCallable",
"AbstractParameter",
]
__all__ = ("ParameterCallable", "AbstractParameter")

import abc
from typing import Any, Protocol, runtime_checkable

import equinox as eqx

from unxt._src.units.api import AstropyUnits
import unxt as u

import galax._custom_types as gt

Expand All @@ -20,7 +17,7 @@ class ParameterCallable(Protocol):
"""Protocol for a Parameter callable."""

def __call__(
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
) -> gt.QuSzAny | gt.SzAny:
"""Compute the parameter value at the given time(s).

Expand Down Expand Up @@ -56,7 +53,7 @@ class AbstractParameter(eqx.Module): # type: ignore[misc]

@abc.abstractmethod
def __call__(
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
) -> gt.QuSzAny:
"""Compute the parameter value at the given time(s).

Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_src/params/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import unxt as u
from dataclassish.converters import Unless
from unxt._src.units.api import AstropyUnits
from unxt.quantity import AllowValue

import galax._custom_types as gt
Expand Down Expand Up @@ -128,7 +127,7 @@ def __call__(
self,
t: gt.BBtQuSz0 = t0, # noqa: ARG002
*,
ustrip: AstropyUnits | None = None,
ustrip: u.AbstractUnit | None = None,
**__: Any,
) -> gt.QuSzAny:
"""Return the constant parameter value.
Expand Down
5 changes: 2 additions & 3 deletions src/galax/potential/_src/params/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import jax.core

import unxt as u
from unxt._src.units.api import AstropyUnits
from unxt.quantity import AllowValue

import galax._custom_types as gt
Expand Down Expand Up @@ -84,7 +83,7 @@ def __check_init__(self) -> None:

@ft.partial(jax.jit, static_argnames=("ustrip",))
def __call__(
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **_: Any
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **_: Any
) -> gt.QuSzAny | gt.SzAny:
"""Return the parameter value.

Expand Down Expand Up @@ -150,7 +149,7 @@ class CustomParameter(AbstractParameter):

@ft.partial(jax.jit, static_argnames=("ustrip",))
def __call__(
self, t: gt.BBtQuSz0, *, ustrip: AstropyUnits | None = None, **kwargs: Any
self, t: gt.BBtQuSz0, *, ustrip: u.AbstractUnit | None = None, **kwargs: Any
) -> gt.QuSzAny | gt.SzAny:
out = self.func(t, **kwargs)
return out if ustrip is None else u.ustrip(AllowValue, ustrip, out)
3 changes: 1 addition & 2 deletions src/galax/potential/_src/xfm/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from plum import dispatch

import unxt as u
from unxt._src.units.api import AstropyUnits
from unxt.quantity import AllowValue

import galax._custom_types as gt
Expand Down Expand Up @@ -172,7 +171,7 @@ def __call__(
self,
t: gt.BBtQorVSz0,
*,
ustrip: AstropyUnits | None = None,
ustrip: u.AbstractUnit | None = None,
**_: Any,
) -> gt.BBtQuSz3:
t = u.ustrip(u.quantity.AllowValue, self.units["time"], t)
Expand Down
35 changes: 24 additions & 11 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading