Skip to content

Commit 34ac960

Browse files
authored
define spacetype for TruncationSpace (#403)
1 parent ec7af8f commit 34ac960

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

src/factorizations/truncation.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ function truncspace(space::ElementarySpace; by = abs, rev::Bool = true)
2525
return TruncationSpace(space, by, rev)
2626
end
2727

28+
TensorKit.spacetype(::Type{<:TruncationSpace{S}}) where {S} = S
29+
2830
# truncate!
2931
# ---------
3032
_blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind)
@@ -257,10 +259,12 @@ MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) =
257259
MAK.findtruncated(values, strategy)
258260

259261
function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace)
262+
sectortype(values) == sectortype(strategy) || throw(SectorMismatch("sectortype of truncation strategy does not match values"))
260263
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
261264
return SectorDict(c => MAK.findtruncated(d, blockstrategy(c)) for (c, d) in pairs(values))
262265
end
263266
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
267+
sectortype(values) == sectortype(strategy) || throw(SectorMismatch("sectortype of truncation strategy does not match values"))
264268
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
265269
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
266270
end

test/factorizations/svd.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ for V in spacelist
212212
@test ϵ1 ϵ2
213213

214214
trunc = truncspace(space(S2, 1))
215+
@test spacetype(typeof(trunc)) == spacetype(W)
216+
@test sectortype(trunc) == sectortype(W)
215217
U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(t; trunc)
216218
@test t * Vᴴ3' U3 * S3
217219
@test isisometric(U3)

0 commit comments

Comments
 (0)