Skip to content

Commit 57ac45d

Browse files
committed
reorganize constructors
fix length check fix reshape Fix typos [skip ci]
1 parent 062081e commit 57ac45d

2 files changed

Lines changed: 182 additions & 147 deletions

File tree

src/tensors/abstracttensor.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,8 @@ function Base.imag(t::AbstractTensorMap)
656656
end
657657
end
658658

659-
# Conversion to Array:
660-
#----------------------
659+
# Conversion to/from Array:
660+
#--------------------------
661661
# probably not optimized for speed, only for checking purposes
662662
function Base.convert(::Type{Array}, t::AbstractTensorMap)
663663
I = sectortype(t)
@@ -678,9 +678,55 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
678678
end
679679
end
680680

681+
"""
682+
project_symmetric!(t::AbstractTensorMap, data::AbstractArray) -> t
683+
684+
Project the data from a dense array `data` into the tensor map `t`. This function discards
685+
any data that does not fit the symmetry structure of `t`.
686+
"""
687+
function project_symmetric!(t::AbstractTensorMap, data::AbstractArray)
688+
# dimension check
689+
codom, dom = codomain(t), domain(t)
690+
arraysize = dims(t)
691+
matsize = (dim(codom), dim(dom))
692+
(size(data) == arraysize || size(data) == matsize) ||
693+
throw(DimensionMismatch("input data has incompatible size for the given tensor"))
694+
data = reshape(collect(data), arraysize)
695+
696+
I = sectortype(t)
697+
if I === Trivial && t isa TensorMap
698+
copy!(t.data, reshape(data, length(t.data)))
699+
return t
700+
end
701+
702+
for ((f₁, f₂), subblock) in subblocks(t)
703+
F = convert(Array, (f₁, f₂))
704+
dataslice = sview(
705+
data, axes(codomain(t), f₁.uncoupled)..., axes(domain(t), f₂.uncoupled)...
706+
)
707+
if FusionStyle(I) === UniqueFusion()
708+
Fscalar = only(F) # contains a single element
709+
scale!(subblock, dataslice, conj(Fscalar))
710+
else
711+
szbF = _interleave(size(F), size(subblock))
712+
indset1 = ntuple(identity, numind(t))
713+
indset2 = 2 .* indset1
714+
indset3 = indset2 .- 1
715+
TensorOperations.tensorcontract!(
716+
subblock,
717+
F, ((), indset1), true,
718+
sreshape(dataslice, szbF), (indset3, indset2), false,
719+
(indset1, ()),
720+
inv(dim(f₁.coupled)), false
721+
)
722+
end
723+
end
724+
725+
return t
726+
end
727+
681728
# Show and friends
682729
# ----------------
683-
684730
function Base.dims2string(V::HomSpace)
685731
str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×')
686732
str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×')

0 commit comments

Comments
 (0)