Skip to content

Commit 5982286

Browse files
authored
Upgrade to NamedDimsArrays.jl v0.14, ITensorBase.jl v0.5 (#54)
1 parent f0e409d commit 5982286

8 files changed

Lines changed: 33 additions & 21 deletions

File tree

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -37,13 +37,13 @@ AlgorithmsInterface = "0.1"
3737
BackendSelection = "0.1.6"
3838
Combinatorics = "1"
3939
DataGraphs = "0.2.7"
40-
DiagonalArrays = "0.3.23"
40+
DiagonalArrays = "0.3.31"
4141
Dictionaries = "0.4.5"
4242
FunctionImplementations = "0.4"
4343
Graphs = "1.13.1"
4444
LinearAlgebra = "1.10"
4545
MacroTools = "0.5.16"
46-
NamedDimsArrays = "0.13"
46+
NamedDimsArrays = "0.14.2"
4747
NamedGraphs = "0.6.9, 0.7, 0.8"
4848
SimpleTraits = "0.9.5"
4949
SplitApplyCombine = "1.2.3"

src/LazyNamedDimsArrays/evaluation_order.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NamedDimsArrays: denamed, inds
1+
using NamedDimsArrays: denamed, dimnames, inds
22
using TermInterface: arguments, arity, operation
33

44
# The time complexity of evaluating `f(args...)`.
@@ -23,7 +23,7 @@ end
2323
function time_complexity(
2424
::typeof(+), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray
2525
)
26-
@assert issetequal(inds(t1), inds(t2))
26+
@assert issetequal(dimnames(t1), dimnames(t2))
2727
return prod(denamed, size(t1))
2828
end
2929
function time_complexity(::typeof(*), c::Number, t::AbstractNamedDimsArray)
@@ -100,7 +100,7 @@ function optimize_contraction_order(alg::Algorithm"eager", a)
100100
# Penalize outer product contractions.
101101
# TODO: Still order the outer products by time complexity,
102102
# say by checking if there are only outer products left.
103-
isdisjoint(inds(a1), inds(a2)) && return typemax(Int)
103+
isdisjoint(dimnames(a1), dimnames(a2)) && return typemax(Int)
104104
return time_complexity(*, a1, a2)
105105
end
106106
contracted_arguments = [filter(((a1, a2)), arguments(a)); [a1 * a2]]

src/LazyNamedDimsArrays/lazyinterface.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NamedDimsArrays: denamed
1+
using NamedDimsArrays: denamed, dimnames, inds
22
using TermInterface: iscall, maketerm, operation, sorted_arguments
33
using WrappedUnions: unwrap
44

@@ -213,6 +213,17 @@ mul_lazy(a1::Number, a2::Number) = a1 * a2
213213
div_lazy(a1, a2::Number) = error("Not implemented.")
214214

215215
# NamedDimsArrays.jl interface.
216+
const dimnames_lazy = lazy_style(dimnames)
217+
function dimnames_lazy(a)
218+
u = unwrap(a)
219+
if !iscall(u)
220+
return dimnames(u)
221+
elseif ismul(u)
222+
return mapreduce(dimnames, symdiff, arguments(u))
223+
else
224+
return error("Variant not supported.")
225+
end
226+
end
216227
const inds_lazy = lazy_style(inds)
217228
function inds_lazy(a)
218229
u = unwrap(a)

src/LazyNamedDimsArrays/lazynameddimsarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ lazy(a::LazyNamedDimsArray) = a
2323
lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a)
2424
lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a)
2525

26+
NamedDimsArrays.dimnames(a::LazyNamedDimsArray) = dimnames_lazy(a)
2627
NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a)
2728
NamedDimsArrays.denamed(a::LazyNamedDimsArray) = denamed_lazy(a)
2829

src/LazyNamedDimsArrays/symbolicnameddimsarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
using NamedDimsArrays: NamedDimsArray, denamed, inds, nameddims
1+
using NamedDimsArrays: NamedDimsArray, denamed, dimnames, name, nameddims
22

33
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
44
NamedDimsArray{T, N, Parent, DimNames}
5-
function symnameddims(name, dims)
6-
return lazy(nameddims(SymbolicArray(name, denamed.(dims)), dims))
5+
function symnameddims(symname, dims)
6+
return lazy(nameddims(SymbolicArray(symname, denamed.(dims)), name.(dims)))
77
end
88
symnameddims(name) = symnameddims(name, ())
99
using AbstractTrees: AbstractTrees
@@ -16,7 +16,7 @@ function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
1616
end
1717
printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a)
1818
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
19-
return issetequal(inds(a), inds(b)) && denamed(a) == denamed(b)
19+
return issetequal(dimnames(a), dimnames(b)) && denamed(a) == denamed(b)
2020
end
2121
Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b)
2222
Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b

src/TensorNetworkGenerators/ising_network.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using DiagonalArrays: DiagonalArray
22
using Graphs: degree, dst, edges, src
33
using ..ITensorNetworksNext: @preserve_graph
44
using LinearAlgebra: Diagonal, eigen
5-
using NamedDimsArrays: apply, denamed, inds, operator, randname
5+
using NamedDimsArrays: apply, denamed, name, operator, randname
66
using NamedGraphs.GraphsExtensions: vertextype
77

88
function sqrt_ising_bond(β; J = one(β), h = zero(β), deg1::Integer, deg2::Integer)
@@ -33,16 +33,16 @@ function ising_network(
3333
(e) = get(() -> l̃[reverse(e)], l̃, e)
3434
tn = delta_network(f̃, elt, g)
3535
for v in sz_vertices
36-
a = DiagonalArray(elt[1, -1], denamed.(inds(tn[v])))
37-
tn[v] = a[inds(tn[v])...]
36+
a = DiagonalArray(elt[1, -1], denamed.(axes(tn[v])))
37+
tn[v] = a[axes(tn[v])...]
3838
end
3939
for e in edges(tn)
4040
v1 = src(e)
4141
v2 = dst(e)
4242
deg1 = degree(tn, v1)
4343
deg2 = degree(tn, v2)
4444
m = sqrt_ising_bond(β; J, h, deg1, deg2)
45-
t = operator(m, ((e),), (f(e),))
45+
t = operator(m, (name((e)),), (name(f(e)),))
4646
@preserve_graph tn[v1] = apply(t, tn[v1])
4747
@preserve_graph tn[v2] = apply(t, tn[v2])
4848
end

src/contract_network.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ function get_order(alg::Algorithm"exact", tn)
3838
end
3939
# Contraction order may or may not have indices attached, canonicalize the format
4040
# by attaching indices.
41-
subs = Dict(symnameddims(i) => symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn))
41+
subs = Dict(symnameddims(i) => symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn))
4242
return substitute(order, subs)
4343
end
4444
function contract_network(alg::Algorithm"exact", tn)
4545
order = get_order(alg, tn)
46-
syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in keys(tn))
46+
syms_to_ts = Dict(symnameddims(i, Tuple(axes(tn[i]))) => lazy(tn[i]) for i in keys(tn))
4747
tn_expression = substitute(order, syms_to_ts)
4848
return materialize(tn_expression)
4949
end
@@ -57,11 +57,11 @@ end
5757
# Convert the tensor network to a flat symbolic multiplication expression.
5858
function contraction_order(alg::Algorithm"flat", tn)
5959
# Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`.
60-
syms = vec([symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)])
60+
syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)])
6161
return lazy(Mul(syms))
6262
end
6363
function contraction_order(alg::Algorithm"left_associative", tn)
64-
return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn))
64+
return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn))
6565
end
6666
function contraction_order(alg::Algorithm, tn)
6767
s = contraction_order(Algorithm"flat"(), tn)

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ Aqua = "0.8.14"
2727
DiagonalArrays = "0.3.23"
2828
Dictionaries = "0.4.5"
2929
Graphs = "1.13.1"
30-
ITensorBase = "0.3, 0.4"
30+
ITensorBase = "0.5"
3131
ITensorNetworksNext = "0.3"
32-
NamedDimsArrays = "0.13"
32+
NamedDimsArrays = "0.14"
3333
NamedGraphs = "0.6.8, 0.7, 0.8"
3434
QuadGK = "2.11.2"
3535
SafeTestsets = "0.1"

0 commit comments

Comments
 (0)