Skip to content

Commit ac94a5a

Browse files
committed
Improve svd error messages
1 parent fb9fafd commit ac94a5a

1 file changed

Lines changed: 41 additions & 28 deletions

File tree

src/tensors/matrixalgebrakit.jl

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,49 +9,62 @@ for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc,
99
end
1010
end
1111

12+
# TODO: move to MatrixAlgebraKit?
13+
macro check_eltype(x, y, f=:identity, g=:eltype)
14+
msg = "unexpected scalar type: "
15+
msg *= string(g) * "(" * string(x) * ") != "
16+
if f == :identity
17+
msg *= string(g) * "(" * string(y) * ")"
18+
else
19+
msg *= string(f) * "(" * string(y) * ")"
20+
end
21+
return :($g($x) == $f($g($y)) || throw(ArgumentError($msg)))
22+
end
23+
1224
# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
1325
# T = scalartype(t)
1426
# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
1527
# end
1628

1729
# Singular value decomposition
1830
# ----------------------------
19-
function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ))
31+
const T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap}
32+
33+
function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap,
34+
(U, S, Vᴴ)::T_USVᴴ)
35+
# scalartype checks
36+
@check_eltype U t
37+
@check_eltype S t real
38+
@check_eltype Vᴴ t
39+
40+
# space checks
2041
V_cod = fuse(codomain(t))
2142
V_dom = fuse(domain(t))
22-
23-
(U isa AbstractTensorMap &&
24-
scalartype(U) == scalartype(t) &&
25-
space(U) == (codomain(t) V_cod)) ||
26-
throw(ArgumentError("`svd_full!` requires unitary tensor U with same `scalartype`"))
27-
(S isa AbstractTensorMap &&
28-
scalartype(S) == real(scalartype(t)) &&
29-
space(S) == (V_cod V_dom)) ||
30-
throw(ArgumentError("`svd_full!` requires rectangular tensor S with real `scalartype`"))
31-
(Vᴴ isa AbstractTensorMap &&
32-
scalartype(Vᴴ) == scalartype(t) &&
33-
space(Vᴴ) == (V_dom domain(t))) ||
34-
throw(ArgumentError("`svd_full!` requires unitary tensor Vᴴ with same `scalartype`"))
43+
space(U) == (codomain(t) V_cod) ||
44+
throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← fuse(domain(t)))`"))
45+
space(S) == (V_cod V_dom) ||
46+
throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(S) == (fuse(codomain(t)) ← fuse(domain(t))`"))
47+
space(Vᴴ) == (V_dom domain(t)) ||
48+
throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (fuse(domain(t)) ← domain(t))`"))
3549

3650
return nothing
3751
end
3852

3953
function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap,
40-
(U, S, Vᴴ))
41-
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
54+
(U, S, Vᴴ)::T_USVᴴ)
55+
# scalartype checks
56+
@check_eltype U t
57+
@check_eltype S t real
58+
@check_eltype Vᴴ t
4259

43-
(U isa AbstractTensorMap &&
44-
scalartype(U) == scalartype(t) &&
45-
space(U) == (codomain(t) V_cod)) ||
46-
throw(ArgumentError("`svd_compact!` requires isometric tensor U with same `scalartype`"))
47-
(S isa DiagonalTensorMap &&
48-
scalartype(S) == real(scalartype(t)) &&
49-
space(S) == (V_cod V_dom)) ||
50-
throw(ArgumentError("`svd_compact!` requires diagonal tensor S with real `scalartype`"))
51-
(Vᴴ isa AbstractTensorMap &&
52-
scalartype(Vᴴ) == scalartype(t) &&
53-
space(Vᴴ) == (V_dom domain(t))) ||
54-
throw(ArgumentError("`svd_compact!` requires isometric tensor Vᴴ with same `scalartype`"))
60+
# space checks
61+
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
62+
space(U) == (codomain(t) V_cod) ||
63+
throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← infimum(fuse(domain(t)), fuse(codomain(t)))`"))
64+
space(S) == (V_cod V_dom) ||
65+
throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires diagonal `S` with `domain(S) == (infimum(fuse(codomain(t)), fuse(domain(t)))`"))
66+
space(Vᴴ) == (V_dom domain(t)) ||
67+
throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(domain(t)), fuse(codomain(t))) ← domain(t))`"))
5568

5669
return nothing
5770
end

0 commit comments

Comments
 (0)