Skip to content

Commit eec591d

Browse files
authored
Merge pull request #1111 from jalvesz/sparse_sym
fix (sparse): ensure proper assembly when filling in symmetric matrices with blocks
2 parents 843ffb0 + 303403b commit eec591d

3 files changed

Lines changed: 141 additions & 23 deletions

File tree

src/sparse/stdlib_sparse_constants.fypp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
33
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
44
module stdlib_sparse_constants
5-
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp
5+
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp, c_bool
66
use stdlib_constants
77
implicit none
88
public

src/sparse/stdlib_sparse_kinds.fypp

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,17 @@ contains
330330
! data accessors
331331
!==================================================================
332332

333+
logical(c_bool) elemental function skip(sym,row,col)
334+
integer(ilp), intent(in) :: sym, row, col
335+
skip = (sym == sparse_lower .and. row < col) .or. (sym == sparse_upper .and. row > col)
336+
end function
337+
333338
#:for k1, t1, s1 in (KINDS_TYPES)
334339
pure ${t1}$ function at_value_coo_${s1}$(self,ik,jk) result(val)
335340
class(COO_${s1}$_type), intent(in) :: self
336341
integer(ilp), intent(in) :: ik, jk
337342
integer(ilp) :: k, ik_, jk_
338-
logical :: transpose
343+
logical(c_bool) :: transpose
339344
! naive implementation
340345
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
341346
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
@@ -373,14 +378,18 @@ contains
373378
class(COO_${s1}$_type), intent(inout) :: self
374379
${t1}$, intent(in) :: val(:,:)
375380
integer(ilp), intent(in) :: ik(:), jk(:)
376-
integer(ilp) :: k, i, j
381+
integer(ilp) :: k, i, j, row, col
377382
! naive implementation
378383
do k = 1, self%nnz
379384
do i = 1, size(ik)
380-
if( ik(i) /= self%index(1,k) ) cycle
385+
row = ik(i)
386+
if( row /= self%index(1,k) ) cycle
381387
do j = 1, size(jk)
382-
if( jk(j) /= self%index(2,k) ) cycle
388+
col = jk(j)
389+
if( skip(self%storage,row,col) ) cycle
390+
if( col /= self%index(2,k) ) cycle
383391
self%data(k) = self%data(k) + val(i,j)
392+
exit
384393
end do
385394
end do
386395
end do
@@ -393,7 +402,7 @@ contains
393402
class(CSR_${s1}$_type), intent(in) :: self
394403
integer(ilp), intent(in) :: ik, jk
395404
integer(ilp) :: k, ik_, jk_
396-
logical :: transpose
405+
logical(c_bool) :: transpose
397406
! naive implementation
398407
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
399408
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
@@ -431,13 +440,17 @@ contains
431440
class(CSR_${s1}$_type), intent(inout) :: self
432441
${t1}$, intent(in) :: val(:,:)
433442
integer(ilp), intent(in) :: ik(:), jk(:)
434-
integer(ilp) :: k, i, j
443+
integer(ilp) :: k, i, j, row, col
435444
! naive implementation
436445
do i = 1, size(ik)
437-
do k = self%rowptr(ik(i)), self%rowptr(ik(i)+1)-1
446+
row = ik(i)
447+
do k = self%rowptr(row), self%rowptr(row+1)-1
438448
do j = 1, size(jk)
439-
if( jk(j) == self%col(k) ) then
449+
col = jk(j)
450+
if( skip(self%storage,row,col) ) cycle
451+
if( col == self%col(k) ) then
440452
self%data(k) = self%data(k) + val(i,j)
453+
exit
441454
end if
442455
end do
443456
end do
@@ -451,7 +464,7 @@ contains
451464
class(CSC_${s1}$_type), intent(in) :: self
452465
integer(ilp), intent(in) :: ik, jk
453466
integer(ilp) :: k, ik_, jk_
454-
logical :: transpose
467+
logical(c_bool) :: transpose
455468
! naive implementation
456469
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
457470
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
@@ -489,13 +502,17 @@ contains
489502
class(CSC_${s1}$_type), intent(inout) :: self
490503
${t1}$, intent(in) :: val(:,:)
491504
integer(ilp), intent(in) :: ik(:), jk(:)
492-
integer(ilp) :: k, i, j
505+
integer(ilp) :: k, i, j, row, col
493506
! naive implementation
494507
do j = 1, size(jk)
495-
do k = self%colptr(jk(j)), self%colptr(jk(j)+1)-1
508+
col = jk(j)
509+
do k = self%colptr(col), self%colptr(col+1)-1
496510
do i = 1, size(ik)
497-
if( ik(i) == self%row(k) ) then
511+
row = ik(i)
512+
if( skip(self%storage,row,col) ) cycle
513+
if( row == self%row(k) ) then
498514
self%data(k) = self%data(k) + val(i,j)
515+
exit
499516
end if
500517
end do
501518
end do
@@ -509,7 +526,7 @@ contains
509526
class(ELL_${s1}$_type), intent(in) :: self
510527
integer(ilp), intent(in) :: ik, jk
511528
integer(ilp) :: k, ik_, jk_
512-
logical :: transpose
529+
logical(c_bool) :: transpose
513530
! naive implementation
514531
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
515532
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
@@ -547,13 +564,17 @@ contains
547564
class(ELL_${s1}$_type), intent(inout) :: self
548565
${t1}$, intent(in) :: val(:,:)
549566
integer(ilp), intent(in) :: ik(:), jk(:)
550-
integer(ilp) :: k, i, j
567+
integer(ilp) :: k, i, j, row, col
551568
! naive implementation
552569
do k = 1 , self%K
553570
do j = 1, size(jk)
571+
col = jk(j)
554572
do i = 1, size(ik)
555-
if( jk(j) == self%index(ik(i),k) ) then
556-
self%data(ik(i),k) = self%data(ik(i),k) + val(i,j)
573+
row = ik(i)
574+
if( skip(self%storage,row,col) ) cycle
575+
if( col == self%index(row,k) ) then
576+
self%data(row,k) = self%data(row,k) + val(i,j)
577+
exit
557578
end if
558579
end do
559580
end do
@@ -567,7 +588,7 @@ contains
567588
class(SELLC_${s1}$_type), intent(in) :: self
568589
integer(ilp), intent(in) :: ik, jk
569590
integer(ilp) :: k, ik_, jk_, idx
570-
logical :: transpose
591+
logical(c_bool) :: transpose
571592
! naive implementation
572593
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
573594
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
@@ -608,14 +629,18 @@ contains
608629
class(SELLC_${s1}$_type), intent(inout) :: self
609630
${t1}$, intent(in) :: val(:,:)
610631
integer(ilp), intent(in) :: ik(:), jk(:)
611-
integer(ilp) :: k, i, j, idx
632+
integer(ilp) :: k, i, j, idx, row, col
612633
! naive implementation
613634
do k = 1 , self%chunk_size
614635
do j = 1, size(jk)
636+
col = jk(j)
615637
do i = 1, size(ik)
616-
idx = self%rowptr((ik(i) - 1)/self%chunk_size + 1)
617-
if( jk(j) == self%col(k,idx) ) then
638+
row = ik(i)
639+
idx = self%rowptr((row - 1)/self%chunk_size + 1)
640+
if( skip(self%storage,row,col) ) cycle
641+
if( col == self%col(k,idx) ) then
618642
self%data(k,idx) = self%data(k,idx) + val(i,j)
643+
exit
619644
end if
620645
end do
621646
end do

test/linalg/test_linalg_sparse.fypp

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ contains
2929
new_unittest('symmetries', test_symmetries), &
3030
new_unittest('diagonal', test_diagonal), &
3131
new_unittest('add_get_values', test_add_get_values), &
32-
new_unittest('sparse_operators', test_sparse_operators) &
32+
new_unittest('sparse_operators', test_sparse_operators), &
33+
new_unittest('add_block_symmetric_skip', test_add_block_symmetric_skip) &
3334
]
3435
end subroutine
3536

@@ -373,7 +374,6 @@ contains
373374

374375
call check(error, all(CSR%data == COO%data) )
375376
if (allocated(error)) return
376-
377377
err = 0._wp
378378
do i = 1, 5
379379
do j = 1, 5
@@ -485,8 +485,101 @@ contains
485485
end block
486486
#:endfor
487487
#:endfor
488+
488489
end subroutine
489490

491+
subroutine test_add_block_symmetric_skip(error)
492+
!> Error handling
493+
type(error_type), allocatable, intent(out) :: error
494+
#:for k1, t1, s1 in (KINDS_TYPES)
495+
block
496+
integer, parameter :: wp = ${k1}$
497+
integer :: connectivity(3,3)
498+
499+
real(wp) :: dense(5,5), dense_low(5,5), mat(3,3)
500+
type(COO_${s1}$_type) :: COO_full, COO_low
501+
type(CSR_${s1}$_type) :: CSR_full, CSR_low
502+
type(CSC_${s1}$_type) :: CSC_full, CSC_low
503+
real(wp) :: x(5), y(5), y_ref(5)
504+
${t1}$:: err
505+
integer :: i, j, locdof(3)
506+
507+
connectivity(1:3,1) = [1,2,3]
508+
connectivity(1:3,2) = [2,3,4]
509+
connectivity(1:3,3) = [3,4,5]
510+
511+
mat(:,1) = [1,2,3]
512+
mat(:,2) = [2,1,4]
513+
mat(:,3) = [3,4,1]
514+
515+
dense = 0._wp
516+
do i = 1, 3
517+
locdof(1:3) = connectivity(1:3,i)
518+
dense(locdof,locdof) = dense(locdof,locdof) + mat
519+
end do
520+
521+
call dense2coo(dense,COO_full)
522+
call coo2csr(COO_full,CSR_full)
523+
call coo2csc(COO_full,CSC_full)
524+
dense_low = dense
525+
do i = 1, 5
526+
do j = i+1, 5
527+
dense_low(i,j) = 0._wp
528+
end do
529+
end do
530+
call dense2coo(dense_low,COO_low)
531+
COO_low%storage = sparse_lower
532+
call coo2csr(COO_low,CSR_low)
533+
call coo2csc(COO_low,CSC_low)
534+
535+
COO_full%data = 0._wp
536+
COO_low%data = 0._wp
537+
CSR_full%data = 0._wp
538+
CSR_low%data = 0._wp
539+
CSC_full%data = 0._wp
540+
CSC_low%data = 0._wp
541+
do i = 1, 3
542+
locdof(1:3) = connectivity(1:3,i)
543+
call COO_full%add(locdof,locdof,mat)
544+
call COO_low%add(locdof,locdof,mat)
545+
call CSR_full%add(locdof,locdof,mat)
546+
call CSR_low%add(locdof,locdof,mat)
547+
call CSC_full%add(locdof,locdof,mat)
548+
call CSC_low%add(locdof,locdof,mat)
549+
end do
550+
551+
call check(error, all(CSR_full%data == COO_full%data) , "error in full CSR ${s1}$ data" )
552+
if (allocated(error)) return
553+
554+
call check(error, all(CSR_low%data == COO_low%data) , "error in low CSR ${s1}$ data" )
555+
if (allocated(error)) return
556+
557+
x = 1._wp
558+
y_ref = matmul(dense,x)
559+
560+
y = 0._wp
561+
call spmv( CSR_full, x, y )
562+
call check(error, all(y == y_ref) , "error in full CSR ${s1}$ spmv" )
563+
if (allocated(error)) return
564+
565+
y = 0._wp
566+
call spmv( CSR_low, x, y )
567+
call check(error, all(y == y_ref) , "error in low CSR ${s1}$ spmv" )
568+
if (allocated(error)) return
569+
570+
y = 0._wp
571+
call spmv( CSC_full, x, y )
572+
call check(error, all(y == y_ref) , "error in full CSC ${s1}$ spmv" )
573+
if (allocated(error)) return
574+
575+
y = 0._wp
576+
call spmv( CSC_low, x, y )
577+
call check(error, all(y == y_ref) , "error in low CSC ${s1}$ spmv" )
578+
end block
579+
#:endfor
580+
581+
end subroutine
582+
490583
end module
491584

492585

0 commit comments

Comments
 (0)