Feature Request: Support for astropy.units.Quantity types in array annotations
Hi @patrick-kidger, thanks a lot for creating such an excellent library!
I'd love to propose adding support for `astropy.units.Quantity` types, which would be a natural extension of jaxtyping's current capabilities.
In scientific computing (and astronomy in particular), we frequently work with quantities that have both numerical values and physical units (e.g., distances in meters, masses in kilograms, times in seconds). The `astropy.units.Quantity` class is the de facto standard for handling dimensional quantities in the Python astronomy community.
Currently, jaxtyping provides type safety for array shapes and dtypes, but there's no built-in way to annotate arrays that carry unit information, so a dimensional mistake such as adding a length to a mass can't be caught when annotations are checked. This looks like a missed opportunity to bring dimensional analysis to the type-checking level.
Proposed feature
Add support for `Quantity` types while preserving jaxtyping's excellent shape and dtype checking, e.g.:
```python
import numpy as np
from astropy.units import Quantity
from jaxtyping import Float


def calculate_physics(
    velocities: Float[Quantity, "n_particles 3"],  # 3D velocities with units
    masses: Float[Quantity, " n_particles"],  # masses with units
) -> Float[Quantity, " n_particles"]:  # kinetic energy with units
    # E_k = 1/2 * m * |v|^2, summing the squared velocity over the 3 spatial axes
    energy = 0.5 * masses * np.sum(velocities**2, axis=-1)
    return energy
```
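More specifically, it would be useful to catch errors like this one at type-checking time:
```python
from astropy.units import Quantity
from jaxtyping import Float


def bad_physics(
    distance: Float[Quantity, " n"],  # expects length units
    mass: Float[Quantity, " n"],  # expects mass units
) -> Float[Quantity, " n"]:
    # Desired: a type error, since length + mass is dimensionally invalid.
    # Today astropy only raises UnitConversionError once this line runs.
    return distance + mass
```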
I'm not sure how straightforward or practical an implementation would be, but I'm happy to brainstorm together!