@@ -130,6 +130,7 @@ def collect(
130130 strict = False ,
131131 tind_auto = False ,
132132 datafile_cache = None ,
133+ zguards = False ,
133134):
134135 """Collect a variable from a set of BOUT++ outputs.
135136
@@ -206,6 +207,7 @@ def getDataFile(i):
206207 prefix ,
207208 strict ,
208209 datafile_cache ,
210+ zguards ,
209211 )
210212
211213 nfiles = len (file_list )
@@ -216,6 +218,7 @@ def getDataFile(i):
216218 f ,
217219 xguards = xguards ,
218220 yguards = yguards ,
221+ zguards = zguards ,
219222 tind = tind ,
220223 xind = xind ,
221224 yind = yind ,
@@ -259,14 +262,17 @@ def getDataFile(i):
259262
260263 if info :
261264 print (
262- "mxsub = {} mysub = {} mz = {}\n " .format (
263- grid_info ["mxsub" ], grid_info ["mysub" ], grid_info ["nz " ]
265+ "mxsub = {} mysub = {} mzsub = {}\n " .format (
266+ grid_info ["mxsub" ], grid_info ["mysub" ], grid_info ["mzsub " ]
264267 )
265268 )
266269
267270 print (
268- "nxpe = {}, nype = {}, npes = {}\n " .format (
269- grid_info ["nxpe" ], grid_info ["nype" ], grid_info ["npes" ]
271+ "nxpe = {}, nype = {}, nzpe = {} npes = {}\n " .format (
272+ grid_info ["nxpe" ],
273+ grid_info ["nype" ],
274+ grid_info ["nzpe" ],
275+ grid_info ["npes" ],
270276 )
271277 )
272278 if grid_info ["npes" ] < nfiles :
@@ -316,6 +322,7 @@ def getDataFile(i):
316322 zind = zind ,
317323 xguards = xguards ,
318324 yguards = (yguards is not False ),
325+ zguards = zguards ,
319326 info = info ,
320327 )
321328 if is_fieldperp :
@@ -342,6 +349,7 @@ def getDataFile(i):
342349 # Finished looping over all files
343350 if info :
344351 sys .stdout .write ("\n " )
352+
345353 return BoutArray (data , attributes = var_attributes )
346354
347355
@@ -359,6 +367,7 @@ def _collect_from_single_file(
359367 prefix ,
360368 strict ,
361369 datafile_cache ,
370+ zguards ,
362371):
363372 """
364373 Collect data from a single file
@@ -394,6 +403,11 @@ def _collect_from_single_file(
394403 except KeyError :
395404 myg = 0
396405 print (f"MYG not found, setting to { myg } " )
406+ try :
407+ mzg = f ["MZG" ]
408+ except KeyError :
409+ mzg = 0
410+ print (f"MZG not found, setting to { mzg } " )
397411
398412 if xguards :
399413 nx = f ["nx" ]
@@ -407,7 +421,11 @@ def _collect_from_single_file(
407421 ny = ny + 2 * myg
408422 else :
409423 ny = f ["ny" ]
410- nz = f ["MZ" ]
424+
425+ if zguards :
426+ nz = f ["nz" ] + 2 * mzg
427+ else :
428+ nz = f ["nz" ]
411429 t_array = f .read ("t_array" )
412430 if t_array is None :
413431 nt = 1
@@ -429,6 +447,8 @@ def _collect_from_single_file(
429447 xind = slice (xind .start + mxg , xind .stop + mxg , xind .step )
430448 if not yguards :
431449 yind = slice (yind .start + myg , yind .stop + myg , yind .step )
450+ if not zguards :
451+ zind = slice (zind .start + mzg , zind .stop + mzg , zind .step )
432452
433453 dim_ranges = {"t" : tind , "x" : xind , "y" : yind , "z" : zind }
434454 ranges = [dim_ranges .get (dim , None ) for dim in dimensions ]
@@ -510,6 +530,7 @@ def _collect_from_one_proc(
510530 zind ,
511531 xguards ,
512532 yguards ,
533+ zguards ,
513534 info ,
514535 parallel_read = False ,
515536):
@@ -594,15 +615,20 @@ def _collect_from_one_proc(
594615
595616 nxpe = grid_info ["nxpe" ]
596617 nype = grid_info ["nype" ]
618+ nzpe = grid_info ["nzpe" ]
597619 mxsub = grid_info ["mxsub" ]
598620 mysub = grid_info ["mysub" ]
621+ mzsub = grid_info ["mzsub" ]
599622 mxg = grid_info ["mxg" ]
600623 myg = grid_info ["myg" ]
624+ mzg = grid_info ["mzg" ]
601625 yproc_upper_target = grid_info ["yproc_upper_target" ]
602626
603- # Get X and Y processor indices
604- pe_yind = i // nxpe
605- pe_xind = i % nxpe
627+ # Get processor indices. `grid_info` only has global data, whereas these are
628+ # specific to each file
629+ pe_xind = datafile .read ("PE_XIND" ) or i % nxpe
630+ pe_yind = datafile .read ("PE_YIND" ) or (i // nxpe ) % nype
631+ pe_zind = datafile .read ("PE_ZIND" ) or i // (nxpe * nype )
606632
607633 inrange = True
608634
@@ -624,18 +650,38 @@ def _collect_from_one_proc(
624650 yguards , yind , pe_yind , nype , yproc_upper_target , mysub , myg , inrange
625651 )
626652
653+ is_field2d = dimensions == ("t" , "x" , "y" ) or dimensions == ("x" , "y" )
654+ if is_field2d :
655+ # Field2Ds do not have a z-dimension, so cannot be sliced in z and should
656+ # always be read regardless of the value of zind (so we should not change
657+ # inrange by checking the z-range).
658+ # zstart, zstop, zgstart and zgstop are set only to avoid errors in 'info'
659+ # messages.
660+ zstart = 0
661+ zstop = 1
662+ zgstart = 0
663+ zgstop = 1
664+ else :
665+ zstart , zstop , zgstart , zgstop , inrange = _get_z_range (
666+ zguards , zind , pe_zind , nzpe , mzsub , mzg , inrange
667+ )
668+
627669 if not inrange :
628670 return None , None # Don't need this file
629671
630672 local_dim_slices = {
631673 "t" : tind ,
632674 "x" : slice (xstart , xstop ),
633675 "y" : slice (ystart , ystop ),
634- "z" : zind ,
676+ "z" : slice ( zstart , zstop ) ,
635677 }
636678 local_slices = tuple (local_dim_slices .get (dim , None ) for dim in dimensions )
637679
638- global_dim_slices = {"x" : slice (xgstart , xgstop ), "y" : slice (ygstart , ygstop )}
680+ global_dim_slices = {
681+ "x" : slice (xgstart , xgstop ),
682+ "y" : slice (ygstart , ygstop ),
683+ "z" : slice (zgstart , zgstop ),
684+ }
639685 if parallel_read :
640686 # When reading in parallel, we are always reading into a 4-dimensional shared
641687 # array. Should not reach this function unless we only have dimensions in
@@ -652,7 +698,8 @@ def _collect_from_one_proc(
652698
653699 if info :
654700 print (
655- f"\r Reading from { i } : [{ xstart } -{ xstop - 1 } ][{ ystart } -{ ystop - 1 } ] -> [{ xgstart } -{ xgstop - 1 } ][{ ygstart } -{ ygstop - 1 } ]\n "
701+ f"\r Reading from { i } : [{ xstart } -{ xstop - 1 } ][{ ystart } -{ ystop - 1 } ][{ zstart } -{ zstop - 1 } ] "
702+ f"-> [{ xgstart } -{ xgstop - 1 } ][{ ygstart } -{ ygstop - 1 } ][{ zgstart } -{ zgstop - 1 } ]\n "
656703 )
657704
658705 if is_fieldperp :
@@ -684,8 +731,7 @@ def _fieldperp_from_this(nype, pe_yind, mysub, myg, temp_yindex):
684731
685732def _check_local_range_lower (start , stop , lower_index , inrange ):
686733 """
687- Utility function for _get_x_range and _get_y_range. Checks inner or lower edge of
688- local ranges.
734+ Utility function for `_get_{x,y,z}_range`. Checks inner or lower edge of local ranges.
689735
690736 Parameters
691737 ----------
@@ -916,6 +962,83 @@ def _get_y_range(yguards, yind, pe_yind, nype, yproc_upper_target, mysub, myg, i
916962 return ystart , ystop , ygstart , ygstop , inrange
917963
918964
965+ def _get_z_range (zguards , zind , pe_zind , nzpe , mzsub , mzg , inrange ):
966+ """
967+ Get local ranges of z-indices
968+
969+ Parameters
970+ ----------
971+ zguards : bool
972+ Include z-boundaries?
973+ zind : slice
974+ Global slice to apply to z-dimension
975+ pe_zind : int
976+ z-indez of the processor
977+ nzpe : int
978+ Number of processors in the z-direction
979+ mzsub : int
980+ Number of grid cells (excluding guard cells) in the z-direction on a single
981+ procssor
982+ mzg : int
983+ Number of guard cells in the z-direction
984+ inrange : bool
985+ Does the processor have data to read?
986+
987+ Returns
988+ -------
989+ zstart : int
990+ Local z-index to start reading
991+ zstop : int
992+ Local z-index to stop reading
993+ zgstart : int
994+ Global z-index to start putting data
995+ zgstop : int
996+ Global z-index to stop putting data
997+ inrange : bool
998+ Updated version of inrange - changed to False if this processor has no data to
999+ read
1000+ """
1001+ # Local ranges
1002+ if zguards :
1003+ zstart = zind .start - pe_zind * mzsub
1004+ zstop = zind .stop - pe_zind * mzsub
1005+
1006+ # Check lower z boundary
1007+ if pe_zind == 0 :
1008+ # Keeping inner boundary
1009+ zstart , inrange = _check_local_range_lower (zstart , zstop , 0 , inrange )
1010+ else :
1011+ zstart , inrange = _check_local_range_lower (zstart , zstop , mzg , inrange )
1012+
1013+ # Upper z boundary
1014+ if pe_zind == (nzpe - 1 ):
1015+ # Keeping outer boundary
1016+ zstop , inrange = _check_local_range_upper (
1017+ zstart , zstop , mzsub + 2 * mzg , inrange
1018+ )
1019+ else :
1020+ zstop , inrange = _check_local_range_upper (
1021+ zstart , zstop , mzsub + mzg , inrange
1022+ )
1023+
1024+ else :
1025+ zstart = zind .start - pe_zind * mzsub + mzg
1026+ zstop = zind .stop - pe_zind * mzsub + mzg
1027+
1028+ zstart , inrange = _check_local_range_lower (zstart , zstop , mzg , inrange )
1029+ zstop , inrange = _check_local_range_upper (zstart , zstop , mzsub + mzg , inrange )
1030+
1031+ # Global ranges
1032+ if zguards :
1033+ zgstart = zstart + pe_zind * mzsub - zind .start
1034+ zgstop = zstop + pe_zind * mzsub - zind .start
1035+ else :
1036+ zgstart = zstart + pe_zind * mzsub - mzg - zind .start
1037+ zgstop = zstop + pe_zind * mzsub - mzg - zind .start
1038+
1039+ return zstart , zstop , zgstart , zgstop , inrange
1040+
1041+
9191042def _check_fieldperp_attributes (
9201043 varname ,
9211044 yindex_global ,
@@ -950,7 +1073,17 @@ def _check_fieldperp_attributes(
9501073
9511074
9521075def _get_grid_info (
953- f , * , xguards , yguards , tind , xind , yind , zind , nfiles , all_vars_info = False
1076+ f ,
1077+ * ,
1078+ xguards ,
1079+ yguards ,
1080+ zguards : bool ,
1081+ tind ,
1082+ xind ,
1083+ yind ,
1084+ zind ,
1085+ nfiles ,
1086+ all_vars_info = False ,
9541087):
9551088 """Get the grid info from an open DataFile
9561089
@@ -993,8 +1126,10 @@ def load_and_check(varname):
9931126
9941127 mxg = int (load_and_check ("MXG" ))
9951128 myg = int (load_and_check ("MYG" ))
1129+ mzg = int (f .read ("MZG" ) or 0 )
9961130 mxsub = int (load_and_check ("MXSUB" ))
9971131 mysub = int (load_and_check ("MYSUB" ))
1132+ mzsub = int (f .read ("MZSUB" ) or mz )
9981133 try :
9991134 nxpe = int (f ["NXPE" ])
10001135 except KeyError :
@@ -1006,6 +1141,9 @@ def load_and_check(varname):
10061141 nype = nfiles
10071142 print (f"NYPE not found, setting to { nype } " )
10081143
1144+ # Don't warn, most files won't have this
1145+ nzpe = int (f .get ("NZPE" , 1 ))
1146+
10091147 if "t_array" in f .keys ():
10101148 nt = len (f .read ("t_array" ))
10111149 else :
@@ -1031,7 +1169,12 @@ def load_and_check(varname):
10311169 else :
10321170 ny = mysub * nype
10331171
1034- nz = mz - 1 if version < 3.5 else mz
1172+ if zguards :
1173+ nz = mzsub * nzpe + 2 * mzg
1174+ elif version < 3.5 :
1175+ nz = mz - 1
1176+ else :
1177+ nz = mzsub * nzpe
10351178
10361179 tind = _convert_to_nice_slice (tind , nt , "tind" )
10371180 xind = _convert_to_nice_slice (xind , nx , "xind" )
@@ -1053,13 +1196,16 @@ def load_and_check(varname):
10531196 "mxsub" : mxsub ,
10541197 "myg" : myg ,
10551198 "mysub" : mysub ,
1199+ "mzg" : mzg ,
1200+ "mzsub" : mzsub ,
10561201 "nt" : nt ,
1057- "npes" : nxpe * nype ,
1202+ "npes" : nxpe * nype * nzpe ,
10581203 "nx" : nx ,
10591204 "nxpe" : nxpe ,
10601205 "ny" : ny ,
10611206 "nype" : nype ,
10621207 "nz" : nz ,
1208+ "nzpe" : nzpe ,
10631209 "sizes" : sizes ,
10641210 "varNames" : varNames ,
10651211 "yproc_upper_target" : yproc_upper_target ,
0 commit comments