Skip to content

Commit 3fe2107

Browse files
committed
Fix a cropping issue with irispy
1 parent 2e27173 commit 3fe2107

2 files changed

Lines changed: 59 additions & 1 deletion

File tree

ndcube/tests/test_ndcube_slice_and_crop.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import numpy as np
44
import pytest
5+
from packaging.version import Version
56

7+
import astropy
68
import astropy.units as u
79
import astropy.wcs
810
from astropy.coordinates import SkyCoord, SpectralCoord
@@ -208,6 +210,33 @@ def test_crop(ndcube_4d_ln_lt_l_t):
208210
helpers.assert_cubes_equal(output, expected)
209211

210212

213+
@pytest.mark.skipif(Version(astropy.__version__) < Version("7.2"), reason="requires astropy>=7.2 for preserve_units")
214+
def test_crop_high_level_coords_with_non_degree_celestial_units():
215+
data = np.arange(10000).reshape(100, 100)
216+
header = fits.Header()
217+
header["NAXIS"] = 2
218+
header["NAXIS1"] = 100
219+
header["NAXIS2"] = 100
220+
header["CTYPE1"] = "RA---TAN"
221+
header["CTYPE2"] = "DEC--TAN"
222+
header["CUNIT1"] = "arcsec"
223+
header["CUNIT2"] = "arcsec"
224+
header["CRPIX1"] = 1
225+
header["CRPIX2"] = 1
226+
header["CRVAL1"] = 0
227+
header["CRVAL2"] = 0
228+
header["CDELT1"] = 1
229+
header["CDELT2"] = 1
230+
cube = NDCube(data, wcs=WCS(header, preserve_units=True))
231+
232+
lower_corner = cube.wcs.pixel_to_world(0, 56)
233+
upper_corner = cube.wcs.pixel_to_world(99, 58)
234+
235+
expected = cube[56:59, 0:100]
236+
output = cube.crop(lower_corner, upper_corner)
237+
helpers.assert_cubes_equal(output, expected)
238+
239+
211240
def test_crop_tuple_non_tuple_input(ndcube_2d_ln_lt):
212241
cube = ndcube_2d_ln_lt
213242
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)

ndcube/utils/cube.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import numpy as np
66

77
import astropy.nddata
8+
import astropy.units as u
89
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS
10+
from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values
911

1012
from ndcube.utils import wcs as wcs_utils
1113
from ndcube.utils.exceptions import warn_user
@@ -107,6 +109,33 @@ def sanitize_crop_inputs(points, wcs):
107109
return False, points, wcs
108110

109111

112+
def _high_level_objects_to_pixel_values(low_level_wcs, *world_objects):
113+
"""
114+
Convert high-level world objects to pixel values.
115+
116+
Astropy's high-level WCS path can hand low-level WCSes celestial values in
117+
degrees even when the low-level WCS advertises angular world units such as
118+
arcsec. Normalize any string-based component units to the WCS world-axis
119+
units before calling the low-level inverse transform.
120+
"""
121+
world_values = list(high_level_objects_to_values(*world_objects, low_level_wcs=low_level_wcs))
122+
for i, (_, _, attr) in enumerate(low_level_wcs.world_axis_object_components):
123+
if not isinstance(attr, str):
124+
continue
125+
source_unit_name = attr.rsplit(".", 1)[-1]
126+
target_unit_name = low_level_wcs.world_axis_units[i]
127+
if not target_unit_name:
128+
continue
129+
try:
130+
source_unit = u.Unit(source_unit_name)
131+
target_unit = u.Unit(target_unit_name)
132+
except Exception: # NOQA: BLE001
133+
continue
134+
if source_unit != target_unit and source_unit.is_equivalent(target_unit):
135+
world_values[i] = (world_values[i] * source_unit).to_value(target_unit)
136+
return low_level_wcs.world_to_pixel_values(*world_values)
137+
138+
110139
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims, original_shape):
111140
"""
112141
Find slice item that crops to minimum cube in array-space containing specified world points.
@@ -182,7 +211,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims, original_sh
182211
# in the list corresponding to its axis.
183212
# Use the to_pixel methods to preserve fractional indices for future rounding.
184213
point_pixel_indices = (sliced_wcs.world_to_pixel_values(*sliced_point) if crop_by_values
185-
else HighLevelWCSWrapper(sliced_wcs).world_to_pixel(*sliced_point))
214+
else _high_level_objects_to_pixel_values(sliced_wcs, *sliced_point))
186215
# For each pixel axis associated with this point, place the pixel coords for
187216
# that pixel axis into the corresponding list within combined_points_pixel_idx.
188217
if sliced_wcs.pixel_n_dim == 1:

0 commit comments

Comments
 (0)