Skip to content

Commit 1bb7826

Browse files
FabioLuporiniJDBetteridge
authored andcommitted
compiler: Fix CustomTopology
1 parent 463f204 commit 1bb7826

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

devito/mpi/distributed.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return DimensionTuple(*self._topology, getters=self.dimensions)
264+
return self._topology
265265

266266
@property
267267
def topology_logical(self):
@@ -353,7 +353,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
353353
self._topology = compute_dims(self._input_comm.size, len(shape))
354354
else:
355355
# A custom topology may contain integers or the wildcard '*'
356-
self._topology = CustomTopology(topology, self._input_comm)
356+
self._topology = CustomTopology(
357+
topology, self._input_comm, getters=dimensions
358+
)
357359

358360
if self._input_comm is not input_comm:
359361
# By default, Devito arranges processes into a cartesian topology.
@@ -896,7 +898,7 @@ def _arg_values(self, *args, **kwargs):
896898
return self._arg_defaults()
897899

898900

899-
class CustomTopology(tuple):
901+
class CustomTopology(DimensionTuple):
900902

901903
"""
902904
The CustomTopology class provides a mechanism to describe parametric domain
@@ -954,7 +956,7 @@ class CustomTopology(tuple):
954956
'xy': ('*', '*', 1),
955957
}
956958

957-
def __new__(cls, items, input_comm):
959+
def __new__(cls, items, input_comm, **kwargs):
958960
# Keep track of nstars and already defined decompositions
959961
nstars = items.count('*')
960962

@@ -992,11 +994,15 @@ def __new__(cls, items, input_comm):
992994
# Final check that topology matches the communicator size
993995
assert np.prod(processed) == input_comm.size
994996

995-
obj = super().__new__(cls, processed)
997+
obj = super().__new__(cls, *processed, **kwargs)
996998
obj.logical = items
997999

9981000
return obj
9991001

1002+
def __repr__(self):
1003+
return (f"CustomTopology(logical={self.logical}, "
1004+
f"physical={super().__repr__()})")
1005+
10001006

10011007
def compute_dims(nprocs, ndim):
10021008
# We don't do anything clever here. In fact, we do something very basic --

0 commit comments

Comments
 (0)