Skip to content

Commit 221c555

Browse files
committed
unroll change
1 parent 3e98566 commit 221c555

3 files changed

Lines changed: 5 additions & 14 deletions

File tree

pylops_mpi/StackedLinearOperator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ def matvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[D
7070

7171
M, N = self.shape
7272
if isinstance(x, StackedDistributedArray):
73-
stacked_shape = (
74-
np.sum(np.fromiter((a.global_shape[0] for a in x.distarrays), dtype=np.int64)),
75-
)
73+
stacked_shape = (np.sum([a.global_shape for a in x.distarrays]), )
7674
if stacked_shape != (N, ):
7775
raise ValueError("dimension mismatch")
7876
if isinstance(x, DistributedArray) and x.global_shape != (N,):
@@ -105,9 +103,7 @@ def rmatvec(self, x: Union[DistributedArray, StackedDistributedArray]) -> Union[
105103

106104
M, N = self.shape
107105
if isinstance(x, StackedDistributedArray):
108-
stacked_shape = (
109-
np.sum(np.fromiter((a.global_shape[0] for a in x.distarrays), dtype=np.int64)),
110-
)
106+
stacked_shape = (np.sum([a.global_shape for a in x.distarrays]), )
111107
if stacked_shape != (M, ):
112108
raise ValueError("dimension mismatch")
113109
if isinstance(x, DistributedArray) and x.global_shape != (M,):

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,8 @@ def __init__(self, ops: Sequence[MPILinearOperator], base_comm: MPI.Comm = MPI.C
170170
dtype: Optional[DTypeLike] = None):
171171
self.ops = ops
172172
dtype = _get_dtype(self.ops) if dtype is None else np.dtype(dtype)
173-
shape = (
174-
int(np.sum(np.fromiter((op.shape[0] for op in ops), dtype=np.int64))),
175-
int(np.sum(np.fromiter((op.shape[1] for op in ops), dtype=np.int64))),
176-
)
173+
shape = (int(np.sum(op.shape[0] for op in ops)),
174+
int(np.sum(op.shape[1] for op in ops)))
177175
super().__init__(shape=shape, dtype=dtype, base_comm=base_comm)
178176

179177
def _matvec(self, x: StackedDistributedArray) -> StackedDistributedArray:

pylops_mpi/basicoperators/VStack.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,7 @@ def __init__(self, ops: Sequence[MPILinearOperator],
183183
self.ops = ops
184184
if len(set(op.shape[1] for op in ops)) > 1:
185185
raise ValueError("Operators have different number of columns")
186-
shape = (
187-
int(np.sum(np.fromiter((op.shape[0] for op in ops), dtype=np.int64))),
188-
ops[0].shape[1],
189-
)
186+
shape = (int(np.sum(op.shape[0] for op in ops)), ops[0].shape[1])
190187
dtype = _get_dtype(self.ops) if dtype is None else np.dtype(dtype)
191188
super().__init__(shape=shape, dtype=dtype, base_comm=base_comm)
192189

0 commit comments

Comments
 (0)