@@ -21,11 +21,19 @@ function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kw
2121 return BlockPermutedDiagonalAlgorithm (alg)
2222end
2323
24- # TODO : this should be replaced with a more general similar function that can handle setting
25- # the blocktype and element type - something like S = similar(A, BlockType(...))
26- function _similar_S (A:: AbstractBlockSparseMatrix , s_axis)
24+ function similar_output (
25+ :: typeof (svd_compact!),
26+ A,
27+ s_axis:: AbstractUnitRange ,
28+ alg:: MatrixAlgebraKit.AbstractAlgorithm ,
29+ )
30+ U = similar (A, axes (A, 1 ), s_axis)
2731 T = real (eltype (A))
28- return BlockSparseArray {T,2,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
32+ # TODO : this should be replaced with a more general similar function that can handle setting
33+ # the blocktype and element type - something like S = similar(A, BlockType(...))
34+ S = BlockSparseMatrix {T,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
35+ Vt = similar (A, s_axis, axes (A, 2 ))
36+ return U, S, Vt
2937end
3038
3139function MatrixAlgebraKit. initialize_output (
@@ -34,33 +42,29 @@ function MatrixAlgebraKit.initialize_output(
3442 bm, bn = blocksize (A)
3543 bmn = min (bm, bn)
3644
37- brows = blocklengths (axes (A, 1 ))
38- bcols = blocklengths (axes (A, 2 ))
39- slengths = Vector {Int} (undef , bmn)
45+ brows = eachblockaxis (axes (A, 1 ))
46+ bcols = eachblockaxis (axes (A, 2 ))
47+ s_axes = similar (brows , bmn)
4048
4149 # fill in values for blocks that are present
4250 bIs = collect (eachblockstoredindex (A))
4351 browIs = Int .(first .(Tuple .(bIs)))
4452 bcolIs = Int .(last .(Tuple .(bIs)))
4553 for bI in eachblockstoredindex (A)
4654 row, col = Int .(Tuple (bI))
47- nrows = brows[row]
48- ncols = bcols[col]
49- slengths[col] = min (nrows, ncols)
55+ s_axes[col] = argmin (length, (brows[row], bcols[col]))
5056 end
5157
5258 # fill in values for blocks that aren't present, pairing them in order of occurence
5359 # this is a convention, which at least gives the expected results for blockdiagonal
5460 emptyrows = setdiff (1 : bm, browIs)
5561 emptycols = setdiff (1 : bn, bcolIs)
5662 for (row, col) in zip (emptyrows, emptycols)
57- slengths [col] = min ( brows[row], bcols[col])
63+ s_axes [col] = argmin (length, ( brows[row], bcols[col]) )
5864 end
5965
60- s_axis = blockedrange (slengths)
61- U = similar (A, axes (A, 1 ), s_axis)
62- S = _similar_S (A, s_axis)
63- Vt = similar (A, s_axis, axes (A, 2 ))
66+ s_axis = mortar_axis (s_axes)
67+ U, S, Vt = similar_output (svd_compact!, A, s_axis, alg)
6468
6569 # allocate output
6670 for bI in eachblockstoredindex (A)
@@ -79,40 +83,46 @@ function MatrixAlgebraKit.initialize_output(
7983 return U, S, Vt
8084end
8185
86+ function similar_output (
87+ :: typeof (svd_full!), A, s_axis:: AbstractUnitRange , alg:: MatrixAlgebraKit.AbstractAlgorithm
88+ )
89+ U = similar (A, axes (A, 1 ), s_axis)
90+ T = real (eltype (A))
91+ S = similar (A, T, (s_axis, axes (A, 2 )))
92+ Vt = similar (A, axes (A, 2 ), axes (A, 2 ))
93+ return U, S, Vt
94+ end
95+
8296function MatrixAlgebraKit. initialize_output (
8397 :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
8498)
8599 bm, bn = blocksize (A)
86100
87- brows = blocklengths (axes (A, 1 ))
88- slengths = copy (brows)
101+ brows = eachblockaxis (axes (A, 1 ))
102+ s_axes = similar (brows)
89103
90104 # fill in values for blocks that are present
91105 bIs = collect (eachblockstoredindex (A))
92106 browIs = Int .(first .(Tuple .(bIs)))
93107 bcolIs = Int .(last .(Tuple .(bIs)))
94108 for bI in eachblockstoredindex (A)
95109 row, col = Int .(Tuple (bI))
96- nrows = brows[row]
97- slengths[col] = nrows
110+ s_axes[col] = brows[row]
98111 end
99112
100113 # fill in values for blocks that aren't present, pairing them in order of occurence
101114 # this is a convention, which at least gives the expected results for blockdiagonal
102115 emptyrows = setdiff (1 : bm, browIs)
103116 emptycols = setdiff (1 : bn, bcolIs)
104117 for (row, col) in zip (emptyrows, emptycols)
105- slengths [col] = brows[row]
118+ s_axes [col] = brows[row]
106119 end
107120 for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
108- slengths [bn + i] = brows[emptyrows[k]]
121+ s_axes [bn + i] = brows[emptyrows[k]]
109122 end
110123
111- s_axis = blockedrange (slengths)
112- U = similar (A, axes (A, 1 ), s_axis)
113- Tr = real (eltype (A))
114- S = similar (A, Tr, (s_axis, axes (A, 2 )))
115- Vt = similar (A, axes (A, 2 ), axes (A, 2 ))
124+ s_axis = mortar_axis (s_axes)
125+ U, S, Vt = similar_output (svd_full!, A, s_axis, alg)
116126
117127 # allocate output
118128 for bI in eachblockstoredindex (A)
0 commit comments