@@ -491,9 +491,10 @@ def test_tau_method2D(variant, nz, nx, bc_val, bc=-1, useMPI=False, plotting=Fal
491491 Dxx = helper .get_differentiation_matrix (axes = (0 ,), p = 2 )
492492
493493 # generate operator
494- _A = helper .get_empty_operator_matrix ()
495- helper .add_equation_lhs (_A , 'u' , {'u' : Dz - Dxx * 1e-1 - Dx })
496- A = helper .convert_operator_matrix_to_operator (_A )
494+ diag = True
495+ _A = helper .get_empty_operator_matrix (diag = diag )
496+ helper .add_equation_lhs (_A , 'u' , {'u' : Dz - Dxx * 1e-1 - Dx }, diag = diag )
497+ A = helper .convert_operator_matrix_to_operator (_A , diag = diag )
497498
498499 # prepare system to solve
499500 A = helper .put_BCs_in_matrix (A )
@@ -608,6 +609,34 @@ def function():
608609 assert track [0 ] == 0 , "possible memory leak with the @cache"
609610
610611
612+ @pytest .mark .base
613+ def test_block_diagonal_operators (N = 16 ):
614+ from pySDC .helpers .spectral_helper import SpectralHelper
615+ import numpy as np
616+
617+ helper = SpectralHelper (comm = None , debug = True )
618+ helper .add_axis ('fft' , N = N )
619+ helper .add_axis ('cheby' , N = N )
620+ helper .add_component (['u' , 'v' ])
621+ helper .setup_fft ()
622+
623+ # generate matrices
624+ Dz = helper .get_differentiation_matrix (axes = (1 ,))
625+ Dx = helper .get_differentiation_matrix (axes = (0 ,))
626+
627+ def get_operator (diag ):
628+ _A = helper .get_empty_operator_matrix (diag = diag )
629+ helper .add_equation_lhs (_A , 'u' , {'u' : Dx }, diag = diag )
630+ helper .add_equation_lhs (_A , 'v' , {'v' : Dz }, diag = diag )
631+ return helper .convert_operator_matrix_to_operator (_A , diag = diag )
632+
633+ AD = get_operator (True )
634+ A = get_operator (False )
635+
636+ assert np .allclose (A .toarray (), AD .toarray ()), 'Operators don\' t match'
637+ assert A .data .nbytes > AD .data .nbytes , 'Block diagonal operator did not conserve memory over general operator'
638+
639+
611640if __name__ == '__main__' :
612641 str_to_bool = lambda me : False if me == 'False' else True
613642 str_to_tuple = lambda arg : tuple (int (me ) for me in arg .split (',' ))
@@ -642,9 +671,10 @@ def function():
642671 # test_differentiation_matrix2D(2**5, 2**5, 'T2U', bx='fft', bz='fft', axes=(-2, -1))
643672 # test_matrix1D(4, 'cheby', 'int')
644673 # test_tau_method(-1, 8, 99, kind='Dirichlet')
645- test_tau_method2D ('T2U' , 2 ** 8 , 2 ** 8 , - 2 , plotting = True )
674+ # test_tau_method2D('T2U', 2**8, 2**8, -2, plotting=True)
646675 # test_filter(6, 6, (0,))
647676 # _test_transform_dealias('fft', 'cheby', (-1, -2))
677+ test_block_diagonal_operators ()
648678 else :
649679 raise NotImplementedError
650680 print ('done' )
0 commit comments