Skip to content

Commit f09369b

Browse files
authored
Merge pull request #87 from chrishavlin/fix_indexing_with_time_isel
fix handling of 2d+time datasets with reversed axes
2 parents 81dae43 + 3388da7 commit f09369b

5 files changed

Lines changed: 61 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies=['yt>=4.2.0', 'h5py>=3.4.0', 'pooch>=1.5.1', 'xarray']
2929

3030
[project.optional-dependencies]
3131
full = ["netCDF4", "scipy", "dask[complete]"]
32-
test = ["pytest", "pytest-cov"]
32+
test = ["pytest", "pytest-cov", "cartopy"]
3333
docs = ["Sphinx==7.2.6", "jinja2==3.1.2", "nbsphinx==0.9.3"]
3434

3535
[tool.black]

yt_xarray/accessor/_readers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def _reader(grid, field_name):
3232
# note that the si, ei are exchanged!
3333
si0 = si.copy()
3434
ei0 = ei.copy()
35-
si[idim] = sel_info.global_dims[idim] - ei0[idim]
36-
ei[idim] = sel_info.global_dims[idim] - si0[idim]
35+
si[idim] = sel_info.global_dims_no_time[idim] - ei0[idim]
36+
ei[idim] = sel_info.global_dims_no_time[idim] - si0[idim]
3737

3838
# step 2: this global start index accounts for indexing after any
3939
# subselections on the xarray DataArray are made. might

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
self.grid_type = None # one of _GridType members
5555
self.cell_widths: list = None
5656
self.global_dims: list = None
57+
self.time_index_number: int = None
5758
self._process_selection(xr_ds)
5859

5960
self.yt_coord_names = _convert_to_yt_internal_coords(self.selected_coords)
@@ -130,7 +131,8 @@ def _process_selection(self, xr_ds):
130131
reverse_axis = [] # axes must be positive-monitonic for yt
131132
reverse_axis_names = []
132133
global_dims = [] # the global shape
133-
for c in full_coords:
134+
time_index_number = None
135+
for icoord, c in enumerate(full_coords):
134136
coord_da = getattr(xr_ds, c) # the full coordinate data array
135137

136138
# check if coordinate values are increasing
@@ -156,6 +158,8 @@ def _process_selection(self, xr_ds):
156158
sel_or_isel = getattr(coord_da, self.sel_dict_type)
157159
coord_vals = sel_or_isel(coord_select).values.astype(np.float64)
158160
is_time_dim = _check_for_time(c, coord_vals)
161+
if is_time_dim:
162+
time_index_number = icoord
159163

160164
if coord_vals.size > 1:
161165
# not positive-monotonic? reverse it for cell width calculations
@@ -193,6 +197,8 @@ def _process_selection(self, xr_ds):
193197
self.ndims = len(n_edges)
194198
self.selected_shape = tuple(n_edges)
195199
self.select_shape_cells = tuple(n_cells)
200+
if time_index_number is not None:
201+
_ = full_dimranges.pop(time_index_number)
196202
self.full_bbox = np.array(full_dimranges).astype(np.float64)
197203
self.selected_bbox = np.array(dimranges).astype(np.float64)
198204
self.full_coords = tuple(full_coords)
@@ -201,9 +207,15 @@ def _process_selection(self, xr_ds):
201207
self.selected_time = time
202208
self.grid_type = grid_type
203209
self.cell_widths = cell_widths
204-
self.reverse_axis = reverse_axis
205210
self.reverse_axis_names = reverse_axis_names
206211
self.global_dims = np.array(global_dims)
212+
if time_index_number is not None:
213+
_ = global_dims.pop(time_index_number)
214+
_ = reverse_axis.pop(time_index_number)
215+
self.reverse_axis = reverse_axis
216+
self.time_index_number = time_index_number
217+
self.global_dims_no_time = np.array(global_dims)
218+
207219
# self.coord_selected_arrays = coord_selected_arrays
208220

209221
# set the yt grid dictionary

yt_xarray/tests/test_xr_to_yt.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,25 @@ def test_reversed_axis(stretched, use_callable, chunksizes):
533533
pdy_lats = slc._generate_container_field("pdy")
534534
assert np.all(pdy_lats > 0)
535535
assert np.all(np.isfinite(vals))
536+
537+
538+
def test_reader_with_2d_space_time_and_reverse_axis():
539+
540+
# test for https://github.com/data-exp-lab/yt_xarray/issues/86
541+
542+
# a base xarray ds to be used in various places.
543+
ds = construct_ds_with_extra_dim(
544+
3,
545+
ncoords=5,
546+
nd_space=2,
547+
reverse_indices=[
548+
1,
549+
],
550+
)
551+
552+
field = ("stream", "test_case_3")
553+
ds_yt = ds.yt.load_grid(
554+
"test_case_3", sel_dict={"time": 0}, geometry="geographic", use_callable=True
555+
)
556+
slc = yt.SlicePlot(ds_yt, "altitude", field)
557+
assert np.all(np.isfinite(slc.frb[field]))

yt_xarray/utilities/_utilities.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os.path
2-
from typing import Optional, Tuple
2+
from typing import List, Optional, Tuple
33

44
import numpy as np
55
import xarray as xr
@@ -111,7 +111,7 @@ def _test_time_coord(nt=5):
111111

112112
def _get_test_coord(
113113
cname, n, minv: Optional[float] = None, maxv: Optional[float] = None
114-
):
114+
) -> np.ndarray:
115115
if cname in known_coord_aliases:
116116
cname = known_coord_aliases[cname]
117117

@@ -140,7 +140,13 @@ def _get_test_coord(
140140
return np.linspace(minv, maxv, n)
141141

142142

143-
def construct_ds_with_extra_dim(icoord: int, dim_name: str = "time"):
143+
def construct_ds_with_extra_dim(
144+
icoord: int,
145+
dim_name: str = "time",
146+
ncoords: Optional[int] = None,
147+
nd_space: int = 3,
148+
reverse_indices: Optional[List[int]] = None,
149+
):
144150
coord_configs = {
145151
0: (dim_name, "x", "y", "z"),
146152
1: (dim_name, "z", "y", "x"),
@@ -150,11 +156,22 @@ def construct_ds_with_extra_dim(icoord: int, dim_name: str = "time"):
150156
}
151157

152158
data_vars = {}
153-
coords = {c: _get_test_coord(c, icoord + 4) for c in coord_configs[icoord]}
159+
if ncoords is None:
160+
ncoords = icoord + 4
161+
162+
full_dims = coord_configs[icoord]
163+
dim_order = [full_dims[idim] for idim in range(nd_space + 1)]
164+
coords = {c: _get_test_coord(c, ncoords) for c in dim_order}
165+
166+
if reverse_indices is not None:
167+
for indx in reverse_indices:
168+
dim = dim_order[indx]
169+
coords[dim] = coords[dim][::-1]
170+
154171
var_shape = tuple([len(c) for c in coords.values()])
155172
vals = np.random.random(var_shape)
156173
fname = f"test_case_{icoord}"
157-
da = xr.DataArray(vals, coords=coords, dims=coord_configs[icoord])
174+
da = xr.DataArray(vals, coords=coords, dims=dim_order)
158175
data_vars[fname] = da
159176

160177
return xr.Dataset(data_vars=data_vars)

0 commit comments

Comments
 (0)