Skip to content

Commit 1953576

Browse files
committed
Support collecting from data files with Z guards
1 parent 0e06267 commit 1953576

2 files changed

Lines changed: 164 additions & 16 deletions

File tree

src/boutdata/collect.py

Lines changed: 162 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def collect(
130130
strict=False,
131131
tind_auto=False,
132132
datafile_cache=None,
133+
zguards=False,
133134
):
134135
"""Collect a variable from a set of BOUT++ outputs.
135136
@@ -206,6 +207,7 @@ def getDataFile(i):
206207
prefix,
207208
strict,
208209
datafile_cache,
210+
zguards,
209211
)
210212

211213
nfiles = len(file_list)
@@ -216,6 +218,7 @@ def getDataFile(i):
216218
f,
217219
xguards=xguards,
218220
yguards=yguards,
221+
zguards=zguards,
219222
tind=tind,
220223
xind=xind,
221224
yind=yind,
@@ -259,14 +262,17 @@ def getDataFile(i):
259262

260263
if info:
261264
print(
262-
"mxsub = {} mysub = {} mz = {}\n".format(
263-
grid_info["mxsub"], grid_info["mysub"], grid_info["nz"]
265+
"mxsub = {} mysub = {} mzsub = {}\n".format(
266+
grid_info["mxsub"], grid_info["mysub"], grid_info["mzsub"]
264267
)
265268
)
266269

267270
print(
268-
"nxpe = {}, nype = {}, npes = {}\n".format(
269-
grid_info["nxpe"], grid_info["nype"], grid_info["npes"]
271+
"nxpe = {}, nype = {}, nzpe = {} npes = {}\n".format(
272+
grid_info["nxpe"],
273+
grid_info["nype"],
274+
grid_info["nzpe"],
275+
grid_info["npes"],
270276
)
271277
)
272278
if grid_info["npes"] < nfiles:
@@ -316,6 +322,7 @@ def getDataFile(i):
316322
zind=zind,
317323
xguards=xguards,
318324
yguards=(yguards is not False),
325+
zguards=zguards,
319326
info=info,
320327
)
321328
if is_fieldperp:
@@ -342,6 +349,7 @@ def getDataFile(i):
342349
# Finished looping over all files
343350
if info:
344351
sys.stdout.write("\n")
352+
345353
return BoutArray(data, attributes=var_attributes)
346354

347355

@@ -359,6 +367,7 @@ def _collect_from_single_file(
359367
prefix,
360368
strict,
361369
datafile_cache,
370+
zguards,
362371
):
363372
"""
364373
Collect data from a single file
@@ -394,6 +403,11 @@ def _collect_from_single_file(
394403
except KeyError:
395404
myg = 0
396405
print(f"MYG not found, setting to {myg}")
406+
try:
407+
mzg = f["MZG"]
408+
except KeyError:
409+
mzg = 0
410+
print(f"MZG not found, setting to {mzg}")
397411

398412
if xguards:
399413
nx = f["nx"]
@@ -407,7 +421,11 @@ def _collect_from_single_file(
407421
ny = ny + 2 * myg
408422
else:
409423
ny = f["ny"]
410-
nz = f["MZ"]
424+
425+
if zguards:
426+
nz = f["nz"] + 2 * mzg
427+
else:
428+
nz = f["nz"]
411429
t_array = f.read("t_array")
412430
if t_array is None:
413431
nt = 1
@@ -429,6 +447,8 @@ def _collect_from_single_file(
429447
xind = slice(xind.start + mxg, xind.stop + mxg, xind.step)
430448
if not yguards:
431449
yind = slice(yind.start + myg, yind.stop + myg, yind.step)
450+
if not zguards:
451+
zind = slice(zind.start + mzg, zind.stop + mzg, zind.step)
432452

433453
dim_ranges = {"t": tind, "x": xind, "y": yind, "z": zind}
434454
ranges = [dim_ranges.get(dim, None) for dim in dimensions]
@@ -510,6 +530,7 @@ def _collect_from_one_proc(
510530
zind,
511531
xguards,
512532
yguards,
533+
zguards,
513534
info,
514535
parallel_read=False,
515536
):
@@ -594,15 +615,20 @@ def _collect_from_one_proc(
594615

595616
nxpe = grid_info["nxpe"]
596617
nype = grid_info["nype"]
618+
nzpe = grid_info["nzpe"]
597619
mxsub = grid_info["mxsub"]
598620
mysub = grid_info["mysub"]
621+
mzsub = grid_info["mzsub"]
599622
mxg = grid_info["mxg"]
600623
myg = grid_info["myg"]
624+
mzg = grid_info["mzg"]
601625
yproc_upper_target = grid_info["yproc_upper_target"]
602626

603-
# Get X and Y processor indices
604-
pe_yind = i // nxpe
605-
pe_xind = i % nxpe
627+
# Get processor indices. `grid_info` only has global data, whereas these are
628+
# specific to each file
629+
pe_xind = datafile.read("PE_XIND") or i % nxpe
630+
pe_yind = datafile.read("PE_YIND") or (i // nxpe) % nype
631+
pe_zind = datafile.read("PE_ZIND") or i // (nxpe * nype)
606632

607633
inrange = True
608634

@@ -624,18 +650,38 @@ def _collect_from_one_proc(
624650
yguards, yind, pe_yind, nype, yproc_upper_target, mysub, myg, inrange
625651
)
626652

653+
is_field2d = dimensions == ("t", "x", "y") or dimensions == ("x", "y")
654+
if is_field2d:
655+
# Field2Ds do not have a z-dimension, so cannot be sliced in z and should
656+
# always be read regardless of the value of zind (so we should not change
657+
# inrange by checking the z-range).
658+
# zstart, zstop, zgstart and zgstop are set only to avoid errors in 'info'
659+
# messages.
660+
zstart = 0
661+
zstop = 1
662+
zgstart = 0
663+
zgstop = 1
664+
else:
665+
zstart, zstop, zgstart, zgstop, inrange = _get_z_range(
666+
zguards, zind, pe_zind, nzpe, mzsub, mzg, inrange
667+
)
668+
627669
if not inrange:
628670
return None, None # Don't need this file
629671

630672
local_dim_slices = {
631673
"t": tind,
632674
"x": slice(xstart, xstop),
633675
"y": slice(ystart, ystop),
634-
"z": zind,
676+
"z": slice(zstart, zstop),
635677
}
636678
local_slices = tuple(local_dim_slices.get(dim, None) for dim in dimensions)
637679

638-
global_dim_slices = {"x": slice(xgstart, xgstop), "y": slice(ygstart, ygstop)}
680+
global_dim_slices = {
681+
"x": slice(xgstart, xgstop),
682+
"y": slice(ygstart, ygstop),
683+
"z": slice(zgstart, zgstop),
684+
}
639685
if parallel_read:
640686
# When reading in parallel, we are always reading into a 4-dimensional shared
641687
# array. Should not reach this function unless we only have dimensions in
@@ -652,7 +698,8 @@ def _collect_from_one_proc(
652698

653699
if info:
654700
print(
655-
f"\rReading from {i}: [{xstart}-{xstop - 1}][{ystart}-{ystop - 1}] -> [{xgstart}-{xgstop - 1}][{ygstart}-{ygstop - 1}]\n"
701+
f"\rReading from {i}: [{xstart}-{xstop - 1}][{ystart}-{ystop - 1}][{zstart}-{zstop - 1}] "
702+
f"-> [{xgstart}-{xgstop - 1}][{ygstart}-{ygstop - 1}][{zgstart}-{zgstop - 1}]\n"
656703
)
657704

658705
if is_fieldperp:
@@ -684,8 +731,7 @@ def _fieldperp_from_this(nype, pe_yind, mysub, myg, temp_yindex):
684731

685732
def _check_local_range_lower(start, stop, lower_index, inrange):
686733
"""
687-
Utility function for _get_x_range and _get_y_range. Checks inner or lower edge of
688-
local ranges.
734+
Utility function for `_get_{x,y,z}_range`. Checks inner or lower edge of local ranges.
689735
690736
Parameters
691737
----------
@@ -916,6 +962,83 @@ def _get_y_range(yguards, yind, pe_yind, nype, yproc_upper_target, mysub, myg, i
916962
return ystart, ystop, ygstart, ygstop, inrange
917963

918964

965+
def _get_z_range(zguards, zind, pe_zind, nzpe, mzsub, mzg, inrange):
966+
"""
967+
Get local ranges of z-indices
968+
969+
Parameters
970+
----------
971+
zguards : bool
972+
Include z-boundaries?
973+
zind : slice
974+
Global slice to apply to z-dimension
975+
pe_zind : int
976+
z-indez of the processor
977+
nzpe : int
978+
Number of processors in the z-direction
979+
mzsub : int
980+
Number of grid cells (excluding guard cells) in the z-direction on a single
981+
procssor
982+
mzg : int
983+
Number of guard cells in the z-direction
984+
inrange : bool
985+
Does the processor have data to read?
986+
987+
Returns
988+
-------
989+
zstart : int
990+
Local z-index to start reading
991+
zstop : int
992+
Local z-index to stop reading
993+
zgstart : int
994+
Global z-index to start putting data
995+
zgstop : int
996+
Global z-index to stop putting data
997+
inrange : bool
998+
Updated version of inrange - changed to False if this processor has no data to
999+
read
1000+
"""
1001+
# Local ranges
1002+
if zguards:
1003+
zstart = zind.start - pe_zind * mzsub
1004+
zstop = zind.stop - pe_zind * mzsub
1005+
1006+
# Check lower z boundary
1007+
if pe_zind == 0:
1008+
# Keeping inner boundary
1009+
zstart, inrange = _check_local_range_lower(zstart, zstop, 0, inrange)
1010+
else:
1011+
zstart, inrange = _check_local_range_lower(zstart, zstop, mzg, inrange)
1012+
1013+
# Upper z boundary
1014+
if pe_zind == (nzpe - 1):
1015+
# Keeping outer boundary
1016+
zstop, inrange = _check_local_range_upper(
1017+
zstart, zstop, mzsub + 2 * mzg, inrange
1018+
)
1019+
else:
1020+
zstop, inrange = _check_local_range_upper(
1021+
zstart, zstop, mzsub + mzg, inrange
1022+
)
1023+
1024+
else:
1025+
zstart = zind.start - pe_zind * mzsub + mzg
1026+
zstop = zind.stop - pe_zind * mzsub + mzg
1027+
1028+
zstart, inrange = _check_local_range_lower(zstart, zstop, mzg, inrange)
1029+
zstop, inrange = _check_local_range_upper(zstart, zstop, mzsub + mzg, inrange)
1030+
1031+
# Global ranges
1032+
if zguards:
1033+
zgstart = zstart + pe_zind * mzsub - zind.start
1034+
zgstop = zstop + pe_zind * mzsub - zind.start
1035+
else:
1036+
zgstart = zstart + pe_zind * mzsub - mzg - zind.start
1037+
zgstop = zstop + pe_zind * mzsub - mzg - zind.start
1038+
1039+
return zstart, zstop, zgstart, zgstop, inrange
1040+
1041+
9191042
def _check_fieldperp_attributes(
9201043
varname,
9211044
yindex_global,
@@ -950,7 +1073,17 @@ def _check_fieldperp_attributes(
9501073

9511074

9521075
def _get_grid_info(
953-
f, *, xguards, yguards, tind, xind, yind, zind, nfiles, all_vars_info=False
1076+
f,
1077+
*,
1078+
xguards,
1079+
yguards,
1080+
zguards: bool,
1081+
tind,
1082+
xind,
1083+
yind,
1084+
zind,
1085+
nfiles,
1086+
all_vars_info=False,
9541087
):
9551088
"""Get the grid info from an open DataFile
9561089
@@ -993,8 +1126,10 @@ def load_and_check(varname):
9931126

9941127
mxg = int(load_and_check("MXG"))
9951128
myg = int(load_and_check("MYG"))
1129+
mzg = int(f.read("MZG") or 0)
9961130
mxsub = int(load_and_check("MXSUB"))
9971131
mysub = int(load_and_check("MYSUB"))
1132+
mzsub = int(f.read("MZSUB") or mz)
9981133
try:
9991134
nxpe = int(f["NXPE"])
10001135
except KeyError:
@@ -1006,6 +1141,9 @@ def load_and_check(varname):
10061141
nype = nfiles
10071142
print(f"NYPE not found, setting to {nype}")
10081143

1144+
# Don't warn, most files won't have this
1145+
nzpe = int(f.get("NZPE", 1))
1146+
10091147
if "t_array" in f.keys():
10101148
nt = len(f.read("t_array"))
10111149
else:
@@ -1031,7 +1169,12 @@ def load_and_check(varname):
10311169
else:
10321170
ny = mysub * nype
10331171

1034-
nz = mz - 1 if version < 3.5 else mz
1172+
if zguards:
1173+
nz = mzsub * nzpe + 2 * mzg
1174+
elif version < 3.5:
1175+
nz = mz - 1
1176+
else:
1177+
nz = mzsub * nzpe
10351178

10361179
tind = _convert_to_nice_slice(tind, nt, "tind")
10371180
xind = _convert_to_nice_slice(xind, nx, "xind")
@@ -1053,13 +1196,16 @@ def load_and_check(varname):
10531196
"mxsub": mxsub,
10541197
"myg": myg,
10551198
"mysub": mysub,
1199+
"mzg": mzg,
1200+
"mzsub": mzsub,
10561201
"nt": nt,
1057-
"npes": nxpe * nype,
1202+
"npes": nxpe * nype * nzpe,
10581203
"nx": nx,
10591204
"nxpe": nxpe,
10601205
"ny": ny,
10611206
"nype": nype,
10621207
"nz": nz,
1208+
"nzpe": nzpe,
10631209
"sizes": sizes,
10641210
"varNames": varNames,
10651211
"yproc_upper_target": yproc_upper_target,

src/boutdata/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,7 @@ def __init__(
11111111
f,
11121112
xguards=self._xguards,
11131113
yguards=self._yguards,
1114+
zguards=False,
11141115
tind=tind,
11151116
xind=xind,
11161117
yind=yind,
@@ -1667,6 +1668,7 @@ def _worker_function(self, connection, proc_list, shared_buffer_raw):
16671668
zind=self.zind,
16681669
xguards=self._xguards,
16691670
yguards=self._yguards,
1671+
zguards=False,
16701672
info=self._info,
16711673
parallel_read=True,
16721674
)

0 commit comments

Comments
 (0)