Skip to content

Commit 3e95927

Browse files
committed
Masks for block matrix
1 parent 6e95f87 commit 3e95927

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

cfpq_matrix/block/block_matrix.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@ def optimize_similarly(self, other: Matrix) -> "OptimizedMatrix":
2424
)
2525

2626
def to_mask(self):
27-
if self.block_matrix_space.is_single_cell(self.shape):
28-
return self.base.to_mask()
29-
else:
30-
return None
27+
return self.base.to_mask()
3128

3229

3330
class CellBlockMatrix(BlockMatrix):
@@ -38,6 +35,8 @@ def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
3835
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
3936
if self.block_matrix_space.is_single_cell(other.shape):
4037
return self.base.mxm(other, op, mask, swap_operands=swap_operands)
38+
if not mask is None:
39+
mask = self.block_matrix_space.repeat_into_hyper_column(mask)
4140
return self.base.mxm(
4241
self.block_matrix_space.hyper_rotate(
4342
other,
@@ -92,6 +91,8 @@ def _force_init_orientation(
9291

9392
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
9493
if self.block_matrix_space.is_single_cell(other.shape):
94+
if not mask is None:
95+
mask = self.block_matrix_space.hyper_rotate(self.block_matrix_space.repeat_into_hyper_column(mask),BlockMatrixOrientation.HORIZONTAL)
9596
return self._force_init_orientation(
9697
BlockMatrixOrientation.HORIZONTAL
9798
if swap_operands

0 commit comments

Comments
 (0)