@@ -34,17 +34,14 @@ function MatrixAlgebraKit.default_algorithm(
3434end
3535
3636function similar_output (
37- :: typeof (svd_compact!),
38- A,
39- s_axis:: AbstractUnitRange ,
40- alg:: MatrixAlgebraKit.AbstractAlgorithm ,
37+ :: typeof (svd_compact!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
4138)
42- U = similar (A, axes (A, 1 ), s_axis )
39+ U = similar (A, axes (A, 1 ), S_axes[ 1 ] )
4340 T = real (eltype (A))
4441 # TODO : this should be replaced with a more general similar function that can handle setting
4542 # the blocktype and element type - something like S = similar(A, BlockType(...))
46- S = BlockSparseMatrix {T,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis) )
47- Vt = similar (A, s_axis , axes (A, 2 ))
43+ S = BlockSparseMatrix {T,Diagonal{T,Vector{T}}} (undef, S_axes )
44+ Vt = similar (A, S_axes[ 2 ] , axes (A, 2 ))
4845 return U, S, Vt
4946end
5047
@@ -56,27 +53,34 @@ function MatrixAlgebraKit.initialize_output(
5653
5754 brows = eachblockaxis (axes (A, 1 ))
5855 bcols = eachblockaxis (axes (A, 2 ))
59- s_axes = similar (brows, bmn)
56+ u_axes = similar (brows, bmn)
57+ v_axes = similar (brows, bmn)
6058
6159 # fill in values for blocks that are present
6260 bIs = collect (eachblockstoredindex (A))
6361 browIs = Int .(first .(Tuple .(bIs)))
6462 bcolIs = Int .(last .(Tuple .(bIs)))
6563 for bI in eachblockstoredindex (A)
6664 row, col = Int .(Tuple (bI))
67- s_axes[col] = argmin (length, (brows[row], bcols[col]))
65+ len = minimum (length, (brows[row], bcols[col]))
66+ u_axes[col] = brows[row][Base. OneTo (len)]
67+ v_axes[col] = bcols[col][Base. OneTo (len)]
6868 end
6969
7070 # fill in values for blocks that aren't present, pairing them in order of occurence
7171 # this is a convention, which at least gives the expected results for blockdiagonal
7272 emptyrows = setdiff (1 : bm, browIs)
7373 emptycols = setdiff (1 : bn, bcolIs)
7474 for (row, col) in zip (emptyrows, emptycols)
75- s_axes[col] = argmin (length, (brows[row], bcols[col]))
75+ len = minimum (length, (brows[row], bcols[col]))
76+ u_axes[col] = brows[row][Base. OneTo (len)]
77+ v_axes[col] = bcols[col][Base. OneTo (len)]
7678 end
7779
78- s_axis = mortar_axis (s_axes)
79- U, S, Vt = similar_output (svd_compact!, A, s_axis, alg)
80+ u_axis = mortar_axis (u_axes)
81+ v_axis = mortar_axis (v_axes)
82+ S_axes = (u_axis, v_axis)
83+ U, S, Vt = similar_output (svd_compact!, A, S_axes, alg)
8084
8185 # allocate output
8286 for bI in eachblockstoredindex (A)
@@ -96,12 +100,12 @@ function MatrixAlgebraKit.initialize_output(
96100end
97101
98102function similar_output (
99- :: typeof (svd_full!), A, s_axis :: AbstractUnitRange , alg:: MatrixAlgebraKit.AbstractAlgorithm
103+ :: typeof (svd_full!), A, S_axes , alg:: MatrixAlgebraKit.AbstractAlgorithm
100104)
101- U = similar (A, axes (A, 1 ), s_axis )
105+ U = similar (A, axes (A, 1 ), S_axes[ 1 ] )
102106 T = real (eltype (A))
103- S = similar (A, T, (s_axis, axes (A, 2 )) )
104- Vt = similar (A, axes (A, 2 ) , axes (A, 2 ))
107+ S = similar (A, T, S_axes )
108+ Vt = similar (A, S_axes[ 2 ] , axes (A, 2 ))
105109 return U, S, Vt
106110end
107111
@@ -111,30 +115,31 @@ function MatrixAlgebraKit.initialize_output(
111115 bm, bn = blocksize (A)
112116
113117 brows = eachblockaxis (axes (A, 1 ))
114- s_axes = similar (brows)
118+ u_axes = similar (brows)
115119
116120 # fill in values for blocks that are present
117121 bIs = collect (eachblockstoredindex (A))
118122 browIs = Int .(first .(Tuple .(bIs)))
119123 bcolIs = Int .(last .(Tuple .(bIs)))
120124 for bI in eachblockstoredindex (A)
121125 row, col = Int .(Tuple (bI))
122- s_axes [col] = brows[row]
126+ u_axes [col] = brows[row]
123127 end
124128
125129 # fill in values for blocks that aren't present, pairing them in order of occurence
126130 # this is a convention, which at least gives the expected results for blockdiagonal
127131 emptyrows = setdiff (1 : bm, browIs)
128132 emptycols = setdiff (1 : bn, bcolIs)
129133 for (row, col) in zip (emptyrows, emptycols)
130- s_axes [col] = brows[row]
134+ u_axes [col] = brows[row]
131135 end
132136 for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
133- s_axes [bn + i] = brows[emptyrows[k]]
137+ u_axes [bn + i] = brows[emptyrows[k]]
134138 end
135139
136- s_axis = mortar_axis (s_axes)
137- U, S, Vt = similar_output (svd_full!, A, s_axis, alg)
140+ u_axis = mortar_axis (u_axes)
141+ S_axes = (u_axis, axes (A, 2 ))
142+ U, S, Vt = similar_output (svd_full!, A, S_axes, alg)
138143
139144 # allocate output
140145 for bI in eachblockstoredindex (A)
0 commit comments