@@ -24,10 +24,7 @@ def optimize_similarly(self, other: Matrix) -> "OptimizedMatrix":
2424 )
2525
2626 def to_mask (self ):
27- if self .block_matrix_space .is_single_cell (self .shape ):
28- return self .base .to_mask ()
29- else :
30- return None
27+ return self .base .to_mask ()
3128
3229
3330class CellBlockMatrix (BlockMatrix ):
@@ -38,6 +35,8 @@ def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
3835 def mxm (self , other : Matrix , op : Semiring , mask :Matrix , swap_operands : bool = False ) -> Matrix :
3936 if self .block_matrix_space .is_single_cell (other .shape ):
4037 return self .base .mxm (other , op , mask , swap_operands = swap_operands )
38+ if not mask is None :
39+ mask = self .block_matrix_space .repeat_into_hyper_column (mask )
4140 return self .base .mxm (
4241 self .block_matrix_space .hyper_rotate (
4342 other ,
@@ -92,6 +91,8 @@ def _force_init_orientation(
9291
9392 def mxm (self , other : Matrix , op : Semiring , mask :Matrix , swap_operands : bool = False ) -> Matrix :
9493 if self .block_matrix_space .is_single_cell (other .shape ):
94+ if not mask is None :
95+ mask = self .block_matrix_space .hyper_rotate (self .block_matrix_space .repeat_into_hyper_column (mask ),BlockMatrixOrientation .HORIZONTAL )
9596 return self ._force_init_orientation (
9697 BlockMatrixOrientation .HORIZONTAL
9798 if swap_operands
0 commit comments