@@ -9,49 +9,62 @@ for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc,
99 end
1010end
1111
12+ # TODO : move to MatrixAlgebraKit?
13+ macro check_eltype (x, y, f= :identity , g= :eltype )
14+ msg = " unexpected scalar type: "
15+ msg *= string (g) * " (" * string (x) * " ) != "
16+ if f == :identity
17+ msg *= string (g) * " (" * string (y) * " )"
18+ else
19+ msg *= string (f) * " (" * string (y) * " )"
20+ end
21+ return :($ g ($ x) == $ f ($ g ($ y)) || throw (ArgumentError ($ msg)))
22+ end
23+
1224# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
1325# T = scalartype(t)
1426# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
1527# end
1628
1729# Singular value decomposition
1830# ----------------------------
19- function MatrixAlgebraKit. check_input (:: typeof (svd_full!), t:: AbstractTensorMap , (U, S, Vᴴ))
31+ const T_USVᴴ = Tuple{<: AbstractTensorMap ,<: AbstractTensorMap ,<: AbstractTensorMap }
32+
33+ function MatrixAlgebraKit. check_input (:: typeof (svd_full!), t:: AbstractTensorMap ,
34+ (U, S, Vᴴ):: T_USV ᴴ)
35+ # scalartype checks
36+ @check_eltype U t
37+ @check_eltype S t real
38+ @check_eltype Vᴴ t
39+
40+ # space checks
2041 V_cod = fuse (codomain (t))
2142 V_dom = fuse (domain (t))
22-
23- (U isa AbstractTensorMap &&
24- scalartype (U) == scalartype (t) &&
25- space (U) == (codomain (t) ← V_cod)) ||
26- throw (ArgumentError (" `svd_full!` requires unitary tensor U with same `scalartype`" ))
27- (S isa AbstractTensorMap &&
28- scalartype (S) == real (scalartype (t)) &&
29- space (S) == (V_cod ← V_dom)) ||
30- throw (ArgumentError (" `svd_full!` requires rectangular tensor S with real `scalartype`" ))
31- (Vᴴ isa AbstractTensorMap &&
32- scalartype (Vᴴ) == scalartype (t) &&
33- space (Vᴴ) == (V_dom ← domain (t))) ||
34- throw (ArgumentError (" `svd_full!` requires unitary tensor Vᴴ with same `scalartype`" ))
43+ space (U) == (codomain (t) ← V_cod) ||
44+ throw (SpaceMismatch (" `svd_full!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← fuse(domain(t)))`" ))
45+ space (S) == (V_cod ← V_dom) ||
46+ throw (SpaceMismatch (" `svd_full!(t, (U, S, Vᴴ))` requires `space(S) == (fuse(codomain(t)) ← fuse(domain(t))`" ))
47+ space (Vᴴ) == (V_dom ← domain (t)) ||
48+ throw (SpaceMismatch (" `svd_full!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (fuse(domain(t)) ← domain(t))`" ))
3549
3650 return nothing
3751end
3852
3953function MatrixAlgebraKit. check_input (:: typeof (svd_compact!), t:: AbstractTensorMap ,
40- (U, S, Vᴴ))
41- V_cod = V_dom = infimum (fuse (codomain (t)), fuse (domain (t)))
54+ (U, S, Vᴴ):: T_USV ᴴ)
55+ # scalartype checks
56+ @check_eltype U t
57+ @check_eltype S t real
58+ @check_eltype Vᴴ t
4259
43- (U isa AbstractTensorMap &&
44- scalartype (U) == scalartype (t) &&
45- space (U) == (codomain (t) ← V_cod)) ||
46- throw (ArgumentError (" `svd_compact!` requires isometric tensor U with same `scalartype`" ))
47- (S isa DiagonalTensorMap &&
48- scalartype (S) == real (scalartype (t)) &&
49- space (S) == (V_cod ← V_dom)) ||
50- throw (ArgumentError (" `svd_compact!` requires diagonal tensor S with real `scalartype`" ))
51- (Vᴴ isa AbstractTensorMap &&
52- scalartype (Vᴴ) == scalartype (t) &&
53- space (Vᴴ) == (V_dom ← domain (t))) ||
54- throw (ArgumentError (" `svd_compact!` requires isometric tensor Vᴴ with same `scalartype`" ))
60+ # space checks
61+ V_cod = V_dom = infimum (fuse (codomain (t)), fuse (domain (t)))
62+ space (U) == (codomain (t) ← V_cod) ||
63+ throw (SpaceMismatch (" `svd_compact!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← infimum(fuse(domain(t)), fuse(codomain(t)))`" ))
64+ space (S) == (V_cod ← V_dom) ||
65+ throw (SpaceMismatch (" `svd_compact!(t, (U, S, Vᴴ))` requires diagonal `S` with `domain(S) == (infimum(fuse(codomain(t)), fuse(domain(t)))`" ))
66+ space (Vᴴ) == (V_dom ← domain (t)) ||
67+ throw (SpaceMismatch (" `svd_compact!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(domain(t)), fuse(codomain(t))) ← domain(t))`" ))
5568
5669 return nothing
5770end
0 commit comments