|
| 1 | +module MPSKitAdaptExt |
| 2 | + |
| 3 | +using TensorKit: space, spacetype |
| 4 | +using MPSKit |
| 5 | +using BlockTensorKit: nonzero_pairs |
| 6 | +using Adapt |
| 7 | + |
| 8 | +function Adapt.adapt_structure(to, mps::FiniteMPS) |
| 9 | + ad = adapt(to) |
| 10 | + adapt_not_missing(x) = ismissing(x) ? x : ad(x) |
| 11 | + |
| 12 | + TA = Base.promote_op(ad, MPSKit.site_type(mps)) |
| 13 | + TB = Base.promote_op(ad, MPSKit.bond_type(mps)) |
| 14 | + |
| 15 | + ALs = map!(adapt_not_missing, similar(mps.ALs, Union{Missing, TA}), mps.ALs) |
| 16 | + ARs = map!(adapt_not_missing, similar(mps.ARs, Union{Missing, TA}), mps.ARs) |
| 17 | + ACs = map!(adapt_not_missing, similar(mps.ACs, Union{Missing, TA}), mps.ACs) |
| 18 | + Cs = map!(adapt_not_missing, similar(mps.Cs, Union{Missing, TB}), mps.Cs) |
| 19 | + |
| 20 | + return FiniteMPS{TA, TB}(ALs, ARs, ACs, Cs) |
| 21 | +end |
| 22 | + |
| 23 | +function Adapt.adapt_structure(to, mps::InfiniteMPS) |
| 24 | + ad = adapt(to) |
| 25 | + AL = map(ad, mps.AL) |
| 26 | + AR = map(ad, mps.AR) |
| 27 | + C = map(ad, mps.C) |
| 28 | + AC = map(ad, mps.AC) |
| 29 | + |
| 30 | + return InfiniteMPS{eltype(AL), eltype(C)}(AL, AR, C, AC) |
| 31 | +end |
| 32 | + |
| 33 | +Adapt.adapt_structure(to, mpo::MPO) = MPO(map(adapt(to), mpo.O)) |
| 34 | + |
| 35 | +function Adapt.adapt_structure(::Type{TorA}, W::MPSKit.JordanMPOTensor) where {TorA <: Union{Number, DenseVector{<:Number}}} |
| 36 | + TT = MPSKit.jordanmpotensortype(spacetype(W), TorA) |
| 37 | + W′ = TT(undef, space(W)) |
| 38 | + ad = adapt(TorA) |
| 39 | + |
| 40 | + for (k, v) in nonzero_pairs(W.A) |
| 41 | + W′.A[k] = ad(v) |
| 42 | + end |
| 43 | + for (k, v) in nonzero_pairs(W.B) |
| 44 | + W′.B[k] = ad(v) |
| 45 | + end |
| 46 | + for (k, v) in nonzero_pairs(W.C) |
| 47 | + W′.C[k] = ad(v) |
| 48 | + end |
| 49 | + for (k, v) in nonzero_pairs(W.D) |
| 50 | + W′.D[k] = ad(v) |
| 51 | + end |
| 52 | + |
| 53 | + return W′ |
| 54 | +end |
| 55 | +Adapt.adapt_structure(to, mpo::MPOHamiltonian) = MPOHamiltonian(map(adapt(to), mpo.W)) |
| 56 | + |
| 57 | +end |
0 commit comments