Skip to content

Commit 4bd0240

Browse files
committed
condense the two load methods
1 parent 34de3b0 commit 4bd0240

1 file changed

Lines changed: 94 additions & 70 deletions

File tree

yt_xarray/accessor/accessor.py

Lines changed: 94 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@ def __init__(self, xarray_obj):
1515
self._bbox = {}
1616
self._field_grids = defaultdict(lambda: None)
1717

18-
def load_grid_from_callable(
19-
self, fields: Optional[List[str]] = None, geometry=None, **kwargs
18+
def _load_uniform_grid(
19+
self,
20+
*args,
21+
fields: Optional[List[str]] = None,
22+
geometry=None,
23+
use_callable=True,
24+
**kwargs,
2025
):
21-
print("loading uniform grid from callable")
26+
2227
if geometry is None:
2328
geomtype = _determine_yt_geomtype(self.coord_type, self._coord_list)
2429
if geomtype is None:
@@ -27,18 +32,6 @@ def load_grid_from_callable(
2732
"geometry = 'geographic' or 'internal_geopgraphic'"
2833
)
2934

30-
def _read_data(handle):
31-
def _reader(grid, field_name):
32-
ftype, fname = field_name
33-
si = grid.get_global_startindex()
34-
ei = si + grid.ActiveDimensions
35-
var = getattr(handle, fname)
36-
data = var[si[0] : ei[0], si[1] : ei[1], si[2] : ei[2]]
37-
return data.values
38-
39-
return _reader
40-
41-
reader = _read_data(self._obj)
4235
if fields is None:
4336
fields = list(self._obj.data_vars)
4437

@@ -54,32 +47,14 @@ def _reader(grid, field_name):
5447
)
5548
raise NotImplementedError(msg)
5649

57-
# again, need to possibly account for stretched grid here... or at
58-
# least check for it and raise a warning...
50+
# need to possibly account for stretched grid here... or at
51+
# least check for it and raise an error...
5952

6053
# single grid, use whole domain for L/R edges
6154
bbox_vals = self.get_single_bbox(fields)
6255
l_e = bbox_vals[:, 0]
6356
r_e = bbox_vals[:, 1]
6457

65-
data = {}
66-
for field in fields:
67-
units = getattr(self._obj.data_vars[field], "units", None) or ""
68-
data[field] = (reader, units)
69-
70-
data.update(
71-
{
72-
"left_edge": l_e,
73-
"right_edge": r_e,
74-
"dimensions": shape,
75-
"level": 0,
76-
}
77-
)
78-
79-
grid_data = [
80-
data,
81-
]
82-
8358
if "length_unit" in kwargs:
8459
length_unit = kwargs.pop("length_unit")
8560
else:
@@ -92,59 +67,108 @@ def _reader(grid, field_name):
9267

9368
coord_list = self._get_yt_coordlist()
9469
geom = (geomtype, coord_list)
70+
if use_callable:
71+
72+
def _read_data(handle):
73+
def _reader(grid, field_name):
74+
ftype, fname = field_name
75+
si = grid.get_global_startindex()
76+
ei = si + grid.ActiveDimensions
77+
var = getattr(handle, fname)
78+
data = var[si[0] : ei[0], si[1] : ei[1], si[2] : ei[2]]
79+
return data.values
80+
81+
return _reader
82+
83+
reader = _read_data(self._obj)
84+
85+
data = {}
86+
for field in fields:
87+
units = getattr(self._obj.data_vars[field], "units", None) or ""
88+
data[field] = (reader, units)
89+
90+
data.update(
91+
{
92+
"left_edge": l_e,
93+
"right_edge": r_e,
94+
"dimensions": shape,
95+
"level": 0,
96+
}
97+
)
9598

96-
return yt.load_amr_grids(
97-
grid_data, shape, geometry=geom, bbox=bbox_vals, length_unit=length_unit
98-
)
99+
grid_data = [
100+
data,
101+
]
102+
103+
return yt.load_amr_grids(
104+
grid_data,
105+
shape,
106+
geometry=geom,
107+
bbox=bbox_vals,
108+
length_unit=length_unit,
109+
**kwargs,
110+
)
99111

100-
def load_uniform_grid(
101-
self, fields: List[str], *args, geometry: Optional[str] = None, **kwargs
112+
else:
113+
# should account for stretched grid here!
114+
data = {field: self._obj[field].values for field in fields}
115+
return yt.load_uniform_grid(
116+
data,
117+
shape,
118+
length_unit=length_unit,
119+
bbox=bbox_vals,
120+
geometry=geom,
121+
**kwargs,
122+
)
123+
124+
def load_grid_from_callable(
125+
self, fields: Optional[List[str]] = None, geometry=None, **kwargs
102126
):
103127
"""
104-
return an in-memory uniform grid yt dataset
128+
returns a uniform grid yt dataset linked to the open xarray handle.
105129
106130
Parameters
107131
----------
108-
fields : list of fields to include
109-
args : any additional positional arguments to pass to yt.load_uniform_grid
132+
fields : list of fields to include. If None, will try to use all fields
110133
geometry : the geometry to pass to yt.load_uniform grid. If not provided,
111134
will attempt to infer.
112135
kwargs : any additional keyword arguments to pass to yt.load_uniform_grid
113136
114137
Returns
115138
-------
116-
result of yt.load_uniform_grid()
139+
yt StreamDataset
117140
118-
"""
119-
print("loading uniform grid")
120-
if geometry is None:
121-
geomtype = _determine_yt_geomtype(self.coord_type, self._coord_list)
122-
if geomtype is None:
123-
raise ValueError(
124-
"Cannot determine yt geometry type, please provide"
125-
"geometry = 'geographic' or 'internal_geopgraphic'"
126-
)
141+
Notes
142+
-----
127143
128-
if "length_unit" in kwargs:
129-
length_unit = kwargs.pop("length_unit")
130-
else:
131-
length_unit = self._infer_length_unit()
132-
if length_unit is None:
133-
raise ValueError(
134-
"cannot determine length_unit, please provide as"
135-
"a keyword argument."
136-
)
144+
This function relies on the stream callable functionality in yt>=4.1.0
145+
in order to read directly from an open xarray handle without creating
146+
additional in-memory copies of the data.
147+
"""
148+
return self._load_uniform_grid(fields=fields, geometry=geometry, **kwargs)
137149

138-
coord_list = self._get_yt_coordlist()
139-
geom = (geomtype, coord_list)
140-
bbox_vals = self.get_single_bbox(fields) # will validate field bboxes
150+
def load_uniform_grid(
151+
self,
152+
fields: Optional[List[str]] = None,
153+
geometry: Optional[str] = None,
154+
**kwargs,
155+
):
156+
"""
157+
return an in-memory uniform grid yt dataset
141158
142-
# should account for stretched grid here!
159+
Parameters
160+
----------
161+
fields : list of fields to include. If None, will try to use all fields
162+
geometry : the geometry to pass to yt.load_uniform grid. If not provided,
163+
will attempt to infer.
164+
kwargs : any additional keyword arguments to pass to yt.load_uniform_grid
143165
144-
data = {field: self._obj[field].values for field in fields}
145-
sizes = data[list(data.keys())[0]].shape
146-
return yt.load_uniform_grid(
147-
data, sizes, length_unit, *args, bbox=bbox_vals, geometry=geom, **kwargs
166+
Returns
167+
-------
168+
yt StreamDataset
169+
"""
170+
return self._load_uniform_grid(
171+
fields=fields, geometry=geometry, use_callable=False, **kwargs
148172
)
149173

150174
def _infer_length_unit(self):

0 commit comments

Comments
 (0)