Skip to content

Commit 31cab87

Browse files
committed
formatter
1 parent 218dd1e commit 31cab87

4 files changed

Lines changed: 95 additions & 43 deletions

File tree

src/fusedgradedmatrix.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ LinearAlgebra.istriu(A::FusedGradedMatrix) = all(LinearAlgebra.istriu, A.blocks)
211211
LinearAlgebra.istril(A::FusedGradedMatrix) = all(LinearAlgebra.istril, A.blocks)
212212
LinearAlgebra.isposdef(A::FusedGradedMatrix) = all(LinearAlgebra.isposdef, A.blocks)
213213

214-
215214
# ======================== similar ========================
216215

217216
function Base.similar(m::FusedGradedMatrix, ::Type{T}) where {T}

src/fusedgradedvector.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ struct FusedGradedVector{T, D <: AbstractVector{T}, S <: SectorRange} <:
9292
sectors::Vector{S},
9393
blocks::Vector{D}
9494
) where {T, D <: AbstractVector{T}, S <: SectorRange}
95-
length(sectors) == length(blocks) || throw(ArgumentError("sectors and blocks must have the same length"))
95+
length(sectors) == length(blocks) ||
96+
throw(ArgumentError("sectors and blocks must have the same length"))
9697
issorted(sectors) || throw(ArgumentError("sectors must be sorted"))
9798
allunique(sectors) || throw(ArgumentError("sectors must be unique"))
9899
return new{T, D, S}(sectors, blocks)

src/matrixalgebrakit.jl

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,44 @@ for f in [
1414
:left_polar, :right_polar,
1515
]
1616
f! = Symbol(f, :!)
17-
@eval function MAK.default_algorithm(::typeof(MAK.$f!), ::Type{T}; kwargs...) where {T <: FusedGradedMatrix}
18-
return GradedBlockAlgorithm(MAK.default_algorithm(MAK.$f!, datatype(BlockSparseArrays.blocktype(T)); kwargs...))
17+
@eval function MAK.default_algorithm(
18+
::typeof(MAK.$f!),
19+
::Type{T};
20+
kwargs...
21+
) where {T <: FusedGradedMatrix}
22+
return GradedBlockAlgorithm(
23+
MAK.default_algorithm(
24+
MAK.$f!,
25+
datatype(BlockSparseArrays.blocktype(T));
26+
kwargs...
27+
)
28+
)
1929
end
2030

2131
@eval function MAK.copy_input(::typeof(MAK.$f), A::FusedGradedMatrix)
22-
return FusedGradedMatrix(A.sectors, map(Base.Fix1(MAK.copy_input, MAK.$f), A.blocks))
32+
return FusedGradedMatrix(
33+
A.sectors,
34+
map(Base.Fix1(MAK.copy_input, MAK.$f), A.blocks)
35+
)
2336
end
2437

25-
@eval function MAK.check_input(::typeof(MAK.$f!), A::FusedGradedMatrix, F::Tuple, alg::GradedBlockAlgorithm)
38+
@eval function MAK.check_input(
39+
::typeof(MAK.$f!),
40+
A::FusedGradedMatrix,
41+
F::Tuple,
42+
alg::GradedBlockAlgorithm
43+
)
2644
for f in F
2745
A.sectors == f.sectors || throw(ArgumentError("non-matching sectors"))
2846
end
2947
return nothing
3048
end
31-
@eval function MAK.check_input(::typeof(MAK.$f!), A::FusedGradedMatrix, F, alg::GradedBlockAlgorithm)
49+
@eval function MAK.check_input(
50+
::typeof(MAK.$f!),
51+
A::FusedGradedMatrix,
52+
F,
53+
alg::GradedBlockAlgorithm
54+
)
3255
A.sectors == F.sectors || throw(ArgumentError("non-matching sectors"))
3356
return nothing
3457
end
@@ -43,29 +66,43 @@ _ensure_inplace!(F::NTuple{N}, F′::NTuple{N}) where {N} = _ensure_inplace!.(F,
4366

4467
# Single-output: null-space functions return FusedGradedMatrix
4568
for f! in [:qr_null!, :lq_null!]
46-
@eval function MAK.initialize_output(::typeof(MAK.$f!), A::FusedGradedMatrix, alg::GradedBlockAlgorithm)
47-
return FusedGradedMatrix(A.sectors, map(a -> MAK.initialize_output(MAK.$f!, a, alg.alg), A.blocks))
69+
@eval function MAK.initialize_output(
70+
::typeof(MAK.$f!),
71+
A::FusedGradedMatrix,
72+
alg::GradedBlockAlgorithm
73+
)
74+
return FusedGradedMatrix(
75+
A.sectors,
76+
map(a -> MAK.initialize_output(MAK.$f!, a, alg.alg), A.blocks)
77+
)
4878
end
4979
@eval function MAK.$f!(A::FusedGradedMatrix, F, alg::GradedBlockAlgorithm)
5080
MAK.check_input(MAK.$f!, A, F, alg)
5181
foreach(A.blocks, F.blocks) do a, f
5282
f′ = MAK.$f!(a, f, alg.alg)
53-
_ensure_inplace!(f′, f)
83+
return _ensure_inplace!(f′, f)
5484
end
5585
return F
5686
end
5787
end
5888

5989
# Single-output: vals functions return FusedGradedVector
6090
for f! in [:svd_vals!, :eig_vals!, :eigh_vals!]
61-
@eval function MAK.initialize_output(::typeof(MAK.$f!), A::FusedGradedMatrix, alg::GradedBlockAlgorithm)
62-
return FusedGradedVector(A.sectors, map(a -> MAK.initialize_output(MAK.$f!, a, alg.alg), A.blocks))
91+
@eval function MAK.initialize_output(
92+
::typeof(MAK.$f!),
93+
A::FusedGradedMatrix,
94+
alg::GradedBlockAlgorithm
95+
)
96+
return FusedGradedVector(
97+
A.sectors,
98+
map(a -> MAK.initialize_output(MAK.$f!, a, alg.alg), A.blocks)
99+
)
63100
end
64101
@eval function MAK.$f!(A::FusedGradedMatrix, F, alg::GradedBlockAlgorithm)
65102
MAK.check_input(MAK.$f!, A, F, alg)
66103
foreach(A.blocks, F.blocks) do a, f
67104
f′ = MAK.$f!(a, f, alg.alg)
68-
_ensure_inplace!(f′, f)
105+
return _ensure_inplace!(f′, f)
69106
end
70107
return F
71108
end
@@ -77,7 +114,11 @@ for f! in [
77114
:eig_full!, :eigh_full!, :svd_compact!, :svd_full!,
78115
:left_polar!, :right_polar!,
79116
]
80-
@eval function MAK.initialize_output(::typeof(MAK.$f!), A::FusedGradedMatrix, alg::GradedBlockAlgorithm)
117+
@eval function MAK.initialize_output(
118+
::typeof(MAK.$f!),
119+
A::FusedGradedMatrix,
120+
alg::GradedBlockAlgorithm
121+
)
81122
sectors = A.sectors
82123
blocks = map(a -> MAK.initialize_output(MAK.$f!, a, alg.alg), A.blocks)
83124
narg = $(startswith(string(f!), "svd") ? 3 : 2)
@@ -90,15 +131,22 @@ for f! in [
90131
MAK.check_input(MAK.$f!, A, F, alg)
91132
foreach(A.blocks, getproperty.(F, :blocks)...) do a, f...
92133
f′ = MAK.$f!(a, f, alg.alg)
93-
_ensure_inplace!(f′, f)
134+
return _ensure_inplace!(f′, f)
94135
end
95136
return F
96137
end
97138
end
98139

99140
# Matrix properties
100141
# -----------------
101-
for f in [:isunitary, :isisometric, :is_left_isometric, :is_right_isometric, :ishermitian, :isantihermitian]
142+
for f in [
143+
:isunitary,
144+
:isisometric,
145+
:is_left_isometric,
146+
:is_right_isometric,
147+
:ishermitian,
148+
:isantihermitian,
149+
]
102150
@eval function MAK.$f(A::FusedGradedMatrix; kwargs...)
103151
return all(x -> MAK.$f(x; kwargs...), A.blocks)
104152
end
@@ -177,8 +225,9 @@ function MAK.findtruncated(v::FusedGradedVector, strategy::MAK.TruncationByOrder
177225
return kept
178226
end
179227
# SVD values are sorted descending within each block but we still need a cross-block comparison
180-
MAK.findtruncated_svd(v::FusedGradedVector, strategy::MAK.TruncationByOrder) =
181-
MAK.findtruncated(v, strategy)
228+
function MAK.findtruncated_svd(v::FusedGradedVector, strategy::MAK.TruncationByOrder)
229+
return MAK.findtruncated(v, strategy)
230+
end
182231

183232
# TruncationByError (truncerror): global cumulative error budget, discard smallest first
184233
function MAK.findtruncated(v::FusedGradedVector, strategy::MAK.TruncationByError)
@@ -214,8 +263,9 @@ function MAK.findtruncated(v::FusedGradedVector, strategy::MAK.TruncationByError
214263
end
215264

216265
# TruncationByError: disambiguate against MAK's findtruncated_svd(::AbstractVector, ::TruncationByError)
217-
MAK.findtruncated_svd(v::FusedGradedVector, strategy::MAK.TruncationByError) =
218-
MAK.findtruncated(v, strategy)
266+
function MAK.findtruncated_svd(v::FusedGradedVector, strategy::MAK.TruncationByError)
267+
return MAK.findtruncated(v, strategy)
268+
end
219269

220270
# TruncationIntersection: intersect per-block results from each component strategy
221271
function MAK.findtruncated(v::FusedGradedVector, strategy::MAK.TruncationIntersection)
@@ -238,7 +288,7 @@ end
238288
function MAK.truncate(
239289
::typeof(MAK.svd_trunc!),
240290
(U, S, Vᴴ)::NTuple{3, FusedGradedMatrix},
241-
strategy::MAK.TruncationStrategy,
291+
strategy::MAK.TruncationStrategy
242292
)
243293
sv = MAK.diagview(S)
244294
ind = MAK.findtruncated_svd(sv, strategy)
@@ -264,7 +314,7 @@ for f! in (:eigh_trunc!, :eig_trunc!)
264314
@eval function MAK.truncate(
265315
::typeof(MAK.$f!),
266316
(D, V)::NTuple{2, FusedGradedMatrix},
267-
strategy::MAK.TruncationStrategy,
317+
strategy::MAK.TruncationStrategy
268318
)
269319
ev = MAK.diagview(D)
270320
ind = MAK.findtruncated(ev, strategy)

test/test_factorizations.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
using GradedArrays: FusedGradedMatrix, FusedGradedVector, GradedBlockAlgorithm, U1, Z2
2-
using LinearAlgebra: Diagonal, I, eigvals, norm, istril, istriu, isposdef
31
import MatrixAlgebraKit as MAK
2+
using GradedArrays: FusedGradedMatrix, FusedGradedVector, GradedBlockAlgorithm, U1, Z2
3+
using LinearAlgebra: Diagonal, I, eigvals, isposdef, istril, istriu, norm
44
using MatrixAlgebraKit: isisometric, isunitary
55
using Test: @test, @testset
66

@@ -30,11 +30,13 @@ function has_positive_diagonal(A)
3030
all((zero(real(T))), imag(diagview(A)))
3131
end
3232
end
33-
isleftnull(N, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
34-
isapprox(norm(A' * N), 0; atol = max(atol, norm(A) * rtol))
33+
function isleftnull(N, A; atol::Real = 0, rtol::Real = precision(eltype(A)))
34+
return isapprox(norm(A' * N), 0; atol = max(atol, norm(A) * rtol))
35+
end
3536

36-
isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
37-
isapprox(norm(A * Nᴴ'), 0; atol = max(atol, norm(A) * rtol))
37+
function isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A)))
38+
return isapprox(norm(A * Nᴴ'), 0; atol = max(atol, norm(A) * rtol))
39+
end
3840

3941
@testset "Factorizations" begin
4042

@@ -108,7 +110,7 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
108110
@test isapprox(
109111
sort(S.blocks[i]; rev = true),
110112
sort(MAK.diagview(S2.blocks[i]); rev = true);
111-
atol = 1.0e-10,
113+
atol = 1.0e-10
112114
)
113115
end
114116
end
@@ -156,7 +158,6 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
156158
end
157159
end
158160

159-
160161
# -----------------------------------------------------------------------
161162
@testset "LQ" begin
162163
@testset "compact" begin
@@ -225,7 +226,7 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
225226
@test isapprox(
226227
sort(D.blocks[i]; by = real),
227228
sort(MAK.diagview(D2.blocks[i]); by = real);
228-
atol = 1.0e-10,
229+
atol = 1.0e-10
229230
)
230231
end
231232
end
@@ -256,7 +257,7 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
256257
@test isapprox(
257258
sort(real.(D.blocks[i])),
258259
sort(real.(MAK.diagview(D2.blocks[i])));
259-
atol = 1.0e-10,
260+
atol = 1.0e-10
260261
)
261262
end
262263
end
@@ -295,14 +296,14 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
295296
using MatrixAlgebraKit: notrunc, truncrank, trunctol, truncerror
296297

297298
@testset "notrunc" begin
298-
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc=notrunc())
299+
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc = notrunc())
299300
@test U isa FusedGradedMatrix
300301
@test S isa FusedGradedMatrix
301302
@test Vᴴ isa FusedGradedMatrix
302303
@test ε 0 atol = precision(eltype(A_rect))
303304
@test A_rect U * S * Vᴴ
304305
@test isisometric(U)
305-
@test isisometric(Vᴴ; side=:right)
306+
@test isisometric(Vᴴ; side = :right)
306307

307308
# same sectors as compact SVD
308309
U0, S0, Vᴴ0 = MAK.svd_compact(A_rect)
@@ -312,19 +313,19 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
312313

313314
@testset "truncrank" begin
314315
maxrank = 4
315-
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc=truncrank(maxrank))
316+
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc = truncrank(maxrank))
316317
@test U isa FusedGradedMatrix
317318
# total number of kept singular values ≤ maxrank
318319
@test sum(size(b, 2) for b in U.blocks) <= maxrank
319320
# reconstruction error ≈ reported truncation error
320321
@test norm(A_rect - U * S * Vᴴ) ε atol = precision(eltype(A_rect))
321322
@test isisometric(U)
322-
@test isisometric(Vᴴ; side=:right)
323+
@test isisometric(Vᴴ; side = :right)
323324
end
324325

325326
@testset "trunctol" begin
326327
atol = 0.5
327-
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc=trunctol(; atol))
328+
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc = trunctol(; atol))
328329
@test U isa FusedGradedMatrix
329330
# all kept singular values are above the tolerance
330331
for b in S.blocks
@@ -335,14 +336,15 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
335336

336337
@testset "truncerror" begin
337338
atol = 0.3
338-
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc=truncerror(; atol))
339+
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc = truncerror(; atol))
339340
@test U isa FusedGradedMatrix
340341
@test ε <= atol + precision(eltype(A_rect))
341342
@test norm(A_rect - U * S * Vᴴ) ε atol = precision(eltype(A_rect))
342343
end
343344

344345
@testset "combined (truncrank & trunctol)" begin
345-
U, S, Vᴴ, ε = MAK.svd_trunc(A_rect; trunc=truncrank(3) & trunctol(; atol=0.3))
346+
U, S, Vᴴ, ε =
347+
MAK.svd_trunc(A_rect; trunc = truncrank(3) & trunctol(; atol = 0.3))
346348
@test U isa FusedGradedMatrix
347349
@test sum(size(b, 2) for b in U.blocks) <= 3
348350
for b in S.blocks
@@ -351,7 +353,7 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
351353
end
352354

353355
@testset "svd_trunc_no_error" begin
354-
U, S, Vᴴ = MAK.svd_trunc_no_error(A_rect; trunc=truncrank(3))
356+
U, S, Vᴴ = MAK.svd_trunc_no_error(A_rect; trunc = truncrank(3))
355357
@test U isa FusedGradedMatrix
356358
@test sum(size(b, 2) for b in U.blocks) <= 3
357359
end
@@ -362,7 +364,7 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
362364
using MatrixAlgebraKit: notrunc, truncrank, trunctol, truncerror
363365

364366
@testset "notrunc" begin
365-
D, V, ε = MAK.eigh_trunc(A_herm; trunc=notrunc())
367+
D, V, ε = MAK.eigh_trunc(A_herm; trunc = notrunc())
366368
@test D isa FusedGradedMatrix
367369
@test V isa FusedGradedMatrix
368370
@test ε 0 atol = precision(eltype(A_herm))
@@ -373,15 +375,15 @@ isrightnull(Nᴴ, A; atol::Real = 0, rtol::Real = precision(eltype(A))) =
373375

374376
@testset "truncrank" begin
375377
maxrank = 5
376-
D, V, ε = MAK.eigh_trunc(A_herm; trunc=truncrank(maxrank))
378+
D, V, ε = MAK.eigh_trunc(A_herm; trunc = truncrank(maxrank))
377379
@test D isa FusedGradedMatrix
378380
@test sum(size(b, 2) for b in V.blocks) <= maxrank
379381
@test isisometric(V)
380382
end
381383

382384
@testset "trunctol (keep largest by abs)" begin
383385
atol = 0.3
384-
D, V, ε = MAK.eigh_trunc(A_herm; trunc=trunctol(; atol))
386+
D, V, ε = MAK.eigh_trunc(A_herm; trunc = trunctol(; atol))
385387
@test D isa FusedGradedMatrix
386388
for b in D.blocks
387389
@test all((atol) abs, MAK.diagview(b))

0 commit comments

Comments
 (0)