@@ -225,6 +225,7 @@ def initField(self, initPath, seed):
225225 def runSimulation (cls , runDir , tEnd , baseDt , tBeg = 0 , logEvery = 100 ,
226226 dtWrite = None , writeVort = False , writeTau = False ,
227227 timeScheme = "RK443" , timeParallel = False , groupTimeProcs = False ,
228+ writeDecomposition = False ,
228229 ** pParams ):
229230
230231 cls .log (f"RBC simulation in { runDir } " )
@@ -243,7 +244,7 @@ def runSimulation(cls, runDir, tEnd, baseDt, tBeg=0, logEvery=100,
243244
244245 if timeParallel :
245246 assert timeScheme == "SDC" , "need timeScheme=SDC for timeParallel"
246- _ , sComm , _ = SDCIMEX_MPI .initSpaceTimeComms (groupTime = groupTimeProcs )
247+ tComm , sComm , gComm = SDCIMEX_MPI .initSpaceTimeComms (groupTime = groupTimeProcs )
247248 pParams .update (sComm = sComm )
248249 if timeParallel == "MPI" :
249250 TimeStepper = SDCIMEX_MPI
@@ -255,6 +256,30 @@ def runSimulation(cls, runDir, tEnd, baseDt, tBeg=0, logEvery=100,
255256
256257 p = cls (** pParams )
257258
259+ if writeDecomposition :
260+ if MPI_RANK == 0 :
261+ with open ("{runDir}/distrib.txt" , "r" ) as f :
262+ f .write ("Parallel distribution on compute cores\n " )
263+ f .write (f" -- space parallelization on { p .sComm .Get_size ()} procs\n " )
264+ if timeParallel :
265+ f .write (f" -- time parallelization on { tComm .Get_size ()} procs\n " )
266+ f .write (f" -- global parallelization on { gComm .Get_size ()} procs\n " )
267+ coords = [p .grids [axis ] for axis in p .axes ]
268+ labels = ["x" , "y" , "z" ]
269+ COMM_WORLD .Barrier ()
270+ sleep (0.0001 * MPI_RANK )
271+ with open ("{runDir}/distrib.txt" , "a" ) as f :
272+ if timeParallel :
273+ out = f"P{ gComm .Get_rank ()} -S{ sComm .Get_rank ()} -T{ tComm .Get_rank ()} :\n "
274+ else :
275+ out = f"P{ MPI_RANK } :\n "
276+ out += f" -- cpu { os .sched_getaffinity (0 )} on { socket .gethostname ()} \n "
277+ out += "\n " .join (
278+ f" -- { d } { c .shape } : [{ c .min (initial = np .inf )} , { c .max (initial = - np .inf )} ]"
279+ for d , c in zip (labels , coords )
280+ )
281+ COMM_WORLD .Barrier ()
282+
258283 dt = baseDt / p .resFactor
259284 nSteps = round (float (tEnd - tBeg )/ dt , ndigits = 3 )
260285 if float (tEnd - tBeg ) != round (nSteps * dt , ndigits = 3 ):
0 commit comments