@@ -15,10 +15,15 @@ def __init__(self, xarray_obj):
1515 self ._bbox = {}
1616 self ._field_grids = defaultdict (lambda : None )
1717
18- def load_grid_from_callable (
19- self , fields : Optional [List [str ]] = None , geometry = None , ** kwargs
18+ def _load_uniform_grid (
19+ self ,
20+ * args ,
21+ fields : Optional [List [str ]] = None ,
22+ geometry = None ,
23+ use_callable = True ,
24+ ** kwargs ,
2025 ):
21- print ( "loading uniform grid from callable" )
26+
2227 if geometry is None :
2328 geomtype = _determine_yt_geomtype (self .coord_type , self ._coord_list )
2429 if geomtype is None :
@@ -27,18 +32,6 @@ def load_grid_from_callable(
2732 "geometry = 'geographic' or 'internal_geopgraphic'"
2833 )
2934
30- def _read_data (handle ):
31- def _reader (grid , field_name ):
32- ftype , fname = field_name
33- si = grid .get_global_startindex ()
34- ei = si + grid .ActiveDimensions
35- var = getattr (handle , fname )
36- data = var [si [0 ] : ei [0 ], si [1 ] : ei [1 ], si [2 ] : ei [2 ]]
37- return data .values
38-
39- return _reader
40-
41- reader = _read_data (self ._obj )
4235 if fields is None :
4336 fields = list (self ._obj .data_vars )
4437
@@ -54,32 +47,14 @@ def _reader(grid, field_name):
5447 )
5548 raise NotImplementedError (msg )
5649
57- # again, need to possibly account for stretched grid here... or at
58- # least check for it and raise a warning ...
50+ # need to possibly account for stretched grid here... or at
51+ # least check for it and raise an error ...
5952
6053 # single grid, use whole domain for L/R edges
6154 bbox_vals = self .get_single_bbox (fields )
6255 l_e = bbox_vals [:, 0 ]
6356 r_e = bbox_vals [:, 1 ]
6457
65- data = {}
66- for field in fields :
67- units = getattr (self ._obj .data_vars [field ], "units" , None ) or ""
68- data [field ] = (reader , units )
69-
70- data .update (
71- {
72- "left_edge" : l_e ,
73- "right_edge" : r_e ,
74- "dimensions" : shape ,
75- "level" : 0 ,
76- }
77- )
78-
79- grid_data = [
80- data ,
81- ]
82-
8358 if "length_unit" in kwargs :
8459 length_unit = kwargs .pop ("length_unit" )
8560 else :
@@ -92,59 +67,108 @@ def _reader(grid, field_name):
9267
9368 coord_list = self ._get_yt_coordlist ()
9469 geom = (geomtype , coord_list )
70+ if use_callable :
71+
72+ def _read_data (handle ):
73+ def _reader (grid , field_name ):
74+ ftype , fname = field_name
75+ si = grid .get_global_startindex ()
76+ ei = si + grid .ActiveDimensions
77+ var = getattr (handle , fname )
78+ data = var [si [0 ] : ei [0 ], si [1 ] : ei [1 ], si [2 ] : ei [2 ]]
79+ return data .values
80+
81+ return _reader
82+
83+ reader = _read_data (self ._obj )
84+
85+ data = {}
86+ for field in fields :
87+ units = getattr (self ._obj .data_vars [field ], "units" , None ) or ""
88+ data [field ] = (reader , units )
89+
90+ data .update (
91+ {
92+ "left_edge" : l_e ,
93+ "right_edge" : r_e ,
94+ "dimensions" : shape ,
95+ "level" : 0 ,
96+ }
97+ )
9598
96- return yt .load_amr_grids (
97- grid_data , shape , geometry = geom , bbox = bbox_vals , length_unit = length_unit
98- )
99+ grid_data = [
100+ data ,
101+ ]
102+
103+ return yt .load_amr_grids (
104+ grid_data ,
105+ shape ,
106+ geometry = geom ,
107+ bbox = bbox_vals ,
108+ length_unit = length_unit ,
109+ ** kwargs ,
110+ )
99111
100- def load_uniform_grid (
101- self , fields : List [str ], * args , geometry : Optional [str ] = None , ** kwargs
112+ else :
113+ # should account for stretched grid here!
114+ data = {field : self ._obj [field ].values for field in fields }
115+ return yt .load_uniform_grid (
116+ data ,
117+ shape ,
118+ length_unit = length_unit ,
119+ bbox = bbox_vals ,
120+ geometry = geom ,
121+ ** kwargs ,
122+ )
123+
124+ def load_grid_from_callable (
125+ self , fields : Optional [List [str ]] = None , geometry = None , ** kwargs
102126 ):
103127 """
104- return an in-memory uniform grid yt dataset
128+ returns a uniform grid yt dataset linked to the open xarray handle.
105129
106130 Parameters
107131 ----------
108- fields : list of fields to include
109- args : any additional positional arguments to pass to yt.load_uniform_grid
132+ fields : list of fields to include. If None, will try to use all fields
110133 geometry : the geometry to pass to yt.load_uniform grid. If not provided,
111134 will attempt to infer.
112135 kwargs : any additional keyword arguments to pass to yt.load_uniform_grid
113136
114137 Returns
115138 -------
116- result of yt.load_uniform_grid()
139+ yt StreamDataset
117140
118- """
119- print ("loading uniform grid" )
120- if geometry is None :
121- geomtype = _determine_yt_geomtype (self .coord_type , self ._coord_list )
122- if geomtype is None :
123- raise ValueError (
124- "Cannot determine yt geometry type, please provide"
125- "geometry = 'geographic' or 'internal_geopgraphic'"
126- )
141+ Notes
142+ -----
127143
128- if "length_unit" in kwargs :
129- length_unit = kwargs .pop ("length_unit" )
130- else :
131- length_unit = self ._infer_length_unit ()
132- if length_unit is None :
133- raise ValueError (
134- "cannot determine length_unit, please provide as"
135- "a keyword argument."
136- )
144+ This function relies on the stream callable functionality in yt>=4.1.0
145+ in order to read directly from an open xarray handle without creating
146+ additional in-memory copies of the data.
147+ """
148+ return self ._load_uniform_grid (fields = fields , geometry = geometry , ** kwargs )
137149
138- coord_list = self ._get_yt_coordlist ()
139- geom = (geomtype , coord_list )
140- bbox_vals = self .get_single_bbox (fields ) # will validate field bboxes
150+ def load_uniform_grid (
151+ self ,
152+ fields : Optional [List [str ]] = None ,
153+ geometry : Optional [str ] = None ,
154+ ** kwargs ,
155+ ):
156+ """
157+ return an in-memory uniform grid yt dataset
141158
142- # should account for stretched grid here!
159+ Parameters
160+ ----------
161+ fields : list of fields to include. If None, will try to use all fields
162+ geometry : the geometry to pass to yt.load_uniform grid. If not provided,
163+ will attempt to infer.
164+ kwargs : any additional keyword arguments to pass to yt.load_uniform_grid
143165
144- data = {field : self ._obj [field ].values for field in fields }
145- sizes = data [list (data .keys ())[0 ]].shape
146- return yt .load_uniform_grid (
147- data , sizes , length_unit , * args , bbox = bbox_vals , geometry = geom , ** kwargs
166+ Returns
167+ -------
168+ yt StreamDataset
169+ """
170+ return self ._load_uniform_grid (
171+ fields = fields , geometry = geometry , use_callable = False , ** kwargs
148172 )
149173
150174 def _infer_length_unit (self ):
0 commit comments