Skip to content

Commit 7838cd3

Browse files
committed
restart.resize: Disable multiprocessing if maxProc=1
Encountering strange errors when using multiprocessing, that disappear if run in serial.
1 parent cd0bc40 commit 7838cd3

1 file changed

Lines changed: 45 additions & 28 deletions

File tree

boutdata/restart.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def resize3DField(var, data, coordsAndSizesTuple, method, mute):
6363
print(
6464
" Resizing "
6565
+ var
66-
+ " to (nx,ny,nz) = ({},{},{})".format(newNx, newNy, newNz)
66+
+ " from (nx,ny,nz) = ({},{},{})".format(*data.shape)
67+
+ " to ({},{},{})".format(newNx, newNy, newNz)
6768
)
6869

6970
# Make the regular grid function (see examples in
@@ -108,6 +109,9 @@ def resize(
108109
NOTE: Can't overwrite
109110
WARNING: Currently only implemented with uniform BOUT++ grid
110111
112+
If errors occur, try running with maxProc=1. That will disable
113+
multiprocessing so will be slow.
114+
111115
Parameters
112116
----------
113117
newNx, newNy, newNz : int
@@ -125,7 +129,8 @@ def resize(
125129
method : {'linear', 'nearest'}, optional
126130
What interpolation method to be used
127131
maxProc : {None, int}, optional
128-
Limits maximum processors to use when interpolating if set
132+
Limits maximum processors to use when interpolating if set.
133+
Set to 1 to disable multiprocessing.
129134
mute : bool, optional
130135
Whether or not output should be printed from this function
131136
@@ -204,9 +209,10 @@ def resize(
204209
zCoordNew = (np.arange(newNz) + zshift) * newDz
205210

206211
# Make a pool of workers
207-
pool = multiprocessing.Pool(maxProc)
208-
# List of jobs and results
209-
jobs = []
212+
if maxProc != 1:
213+
pool = multiprocessing.Pool(maxProc)
214+
# List of jobs and results
215+
jobs = []
210216
# Pack input to resize3DField together
211217
coordsAndSizesTuple = (
212218
xCoordOld,
@@ -228,19 +234,29 @@ def resize(
228234

229235
# Find 3D variables
230236
if old.ndims(var) == 3:
231-
# Asynchronous call (locks first at .get())
232-
jobs.append(
233-
pool.apply_async(
234-
resize3DField,
235-
args=(
236-
var,
237-
data,
238-
coordsAndSizesTuple,
239-
method,
240-
mute,
241-
),
237+
if maxProc != 1:
238+
# Asynchronous call (locks first at .get())
239+
jobs.append(
240+
pool.apply_async(
241+
resize3DField,
242+
args=(
243+
var,
244+
data,
245+
coordsAndSizesTuple,
246+
method,
247+
mute,
248+
),
249+
)
242250
)
243-
)
251+
else:
252+
# Synchronous call. Easier for debugging
253+
_, newData = resize3DField(
254+
var, data, coordsAndSizesTuple, method, mute
255+
)
256+
newData = BoutArray(newData, attributes=attributes)
257+
if not (mute):
258+
print("Writing " + var)
259+
new.write(var, newData)
244260

245261
else:
246262
if not (mute):
@@ -250,17 +266,18 @@ def resize(
250266
print("Writing " + var)
251267
new.write(var, newData)
252268

253-
for job in jobs:
254-
var, newData = job.get()
255-
newData = BoutArray(newData, attributes=attributes)
256-
if not (mute):
257-
print("Writing " + var)
258-
new.write(var, newData)
259-
260-
# Close the pool of workers
261-
pool.close()
262-
# Wait for all processes to finish
263-
pool.join()
269+
if maxProc != 1:
270+
for job in jobs:
271+
var, newData = job.get()
272+
newData = BoutArray(newData, attributes=attributes)
273+
if not (mute):
274+
print("Writing " + var)
275+
new.write(var, newData)
276+
277+
# Close the pool of workers
278+
pool.close()
279+
# Wait for all processes to finish
280+
pool.join()
264281

265282
return True
266283

0 commit comments

Comments
 (0)