Skip to content

Commit bc32491

Browse files
authored
Merge branch 'main' into bp
2 parents 2f5c783 + 633c54e commit bc32491

12 files changed

Lines changed: 61 additions & 62 deletions

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.3.1"
4+
version = "0.3.3"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -10,9 +10,9 @@ AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d"
1010
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
1111
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1212
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
13-
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1413
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
1514
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
15+
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
1616
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -37,13 +37,14 @@ AlgorithmsInterface = "0.1"
3737
BackendSelection = "0.1.6"
3838
Combinatorics = "1"
3939
DataGraphs = "0.2.7"
40-
DerivableInterfaces = "0.5.5"
4140
DiagonalArrays = "0.3.23"
4241
Dictionaries = "0.4.5"
42+
FunctionImplementations = "0.3"
4343
Graphs = "1.13.1"
4444
LinearAlgebra = "1.10"
4545
MacroTools = "0.5.16"
4646
NamedDimsArrays = "0.8, 0.9, 0.10, 0.11"
47+
NamedDimsArrays = "0.13"
4748
NamedGraphs = "0.6.9, 0.7, 0.8"
4849
SimpleTraits = "0.9.5"
4950
SplitApplyCombine = "1.2.3"

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ include("beliefpropagation/abstractbeliefpropagationcache.jl")
1313
include("beliefpropagation/beliefpropagationcache.jl")
1414
include("beliefpropagation/beliefpropagationproblem.jl")
1515

16-
end
16+
end

src/LazyNamedDimsArrays/evaluation_order.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NamedDimsArrays: dename, inds
1+
using NamedDimsArrays: denamed, inds
22
using TermInterface: arguments, arity, operation
33

44
# The time complexity of evaluating `f(args...)`.
@@ -18,16 +18,16 @@ using NamedDimsArrays: AbstractNamedDimsArray
1818
function time_complexity(
1919
::typeof(*), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray
2020
)
21-
return prod(length dename, (inds(t1) inds(t2)))
21+
return prod(length denamed, (inds(t1) inds(t2)))
2222
end
2323
function time_complexity(
2424
::typeof(+), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray
2525
)
2626
@assert issetequal(inds(t1), inds(t2))
27-
return prod(dename, size(t1))
27+
return prod(denamed, size(t1))
2828
end
2929
function time_complexity(::typeof(*), c::Number, t::AbstractNamedDimsArray)
30-
return prod(dename, size(t))
30+
return prod(denamed, size(t))
3131
end
3232
function time_complexity(::typeof(*), t::AbstractNamedDimsArray, c::Number)
3333
return time_complexity(*, c, t)

src/LazyNamedDimsArrays/lazybroadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NamedDimsArrays: AbstractNamedDimsArrayStyle
1+
using NamedDimsArrays.Broadcast: AbstractNamedDimsArrayStyle
22

33
# Lazy broadcasting.
44
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end

src/LazyNamedDimsArrays/lazyinterface.jl

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NamedDimsArrays: dename
1+
using NamedDimsArrays: denamed
22
using TermInterface: iscall, maketerm, operation, sorted_arguments
33
using WrappedUnions: unwrap
44

@@ -23,22 +23,19 @@ opwalk(opmap, a) = walk(opmap, identity, a)
2323
argwalk(argmap, a) = walk(identity, argmap, a)
2424

2525
# Generic lazy functionality.
26-
using DerivableInterfaces: AbstractArrayInterface, InterfaceFunction
27-
struct LazyInterface{N} <: AbstractArrayInterface{N} end
28-
LazyInterface() = LazyInterface{Any}()
29-
LazyInterface(::Val{N}) where {N} = LazyInterface{N}()
30-
LazyInterface{M}(::Val{N}) where {M, N} = LazyInterface{N}()
31-
const lazy_interface = LazyInterface()
26+
using FunctionImplementations: AbstractArrayStyle
27+
struct LazyStyle <: AbstractArrayStyle end
28+
const lazy_style = LazyStyle()
3229

33-
const maketerm_lazy = lazy_interface(maketerm)
30+
const maketerm_lazy = lazy_style(maketerm)
3431
function maketerm_lazy(type::Type, head, args, metadata)
3532
if head *
3633
return type(maketerm(Mul, head, args, metadata))
3734
else
3835
return error("Only mul supported right now.")
3936
end
4037
end
41-
const getindex_lazy = lazy_interface(getindex)
38+
const getindex_lazy = lazy_style(getindex)
4239
function getindex_lazy(a::AbstractArray, I...)
4340
u = unwrap(a)
4441
if !iscall(u)
@@ -47,7 +44,7 @@ function getindex_lazy(a::AbstractArray, I...)
4744
return error("Indexing into expression not supported.")
4845
end
4946
end
50-
const arguments_lazy = lazy_interface(arguments)
47+
const arguments_lazy = lazy_style(arguments)
5148
function arguments_lazy(a)
5249
u = unwrap(a)
5350
if !iscall(u)
@@ -59,17 +56,17 @@ function arguments_lazy(a)
5956
end
6057
end
6158
using TermInterface: children
62-
const children_lazy = lazy_interface(children)
59+
const children_lazy = lazy_style(children)
6360
children_lazy(a) = arguments(a)
6461
using TermInterface: head
65-
const head_lazy = lazy_interface(head)
62+
const head_lazy = lazy_style(head)
6663
head_lazy(a) = operation(a)
67-
const iscall_lazy = lazy_interface(iscall)
64+
const iscall_lazy = lazy_style(iscall)
6865
iscall_lazy(a) = iscall(unwrap(a))
6966
using TermInterface: isexpr
70-
const isexpr_lazy = lazy_interface(isexpr)
67+
const isexpr_lazy = lazy_style(isexpr)
7168
isexpr_lazy(a) = iscall(a)
72-
const operation_lazy = lazy_interface(operation)
69+
const operation_lazy = lazy_style(operation)
7370
function operation_lazy(a)
7471
u = unwrap(a)
7572
if !iscall(u)
@@ -80,7 +77,7 @@ function operation_lazy(a)
8077
return error("Variant not supported.")
8178
end
8279
end
83-
const sorted_arguments_lazy = lazy_interface(sorted_arguments)
80+
const sorted_arguments_lazy = lazy_style(sorted_arguments)
8481
function sorted_arguments_lazy(a)
8582
u = unwrap(a)
8683
if !iscall(u)
@@ -92,12 +89,12 @@ function sorted_arguments_lazy(a)
9289
end
9390
end
9491
using TermInterface: sorted_children
95-
const sorted_children_lazy = lazy_interface(sorted_children)
92+
const sorted_children_lazy = lazy_style(sorted_children)
9693
sorted_children_lazy(a) = sorted_arguments(a)
97-
const ismul_lazy = lazy_interface(ismul)
94+
const ismul_lazy = lazy_style(ismul)
9895
ismul_lazy(a) = ismul(unwrap(a))
9996
using AbstractTrees: AbstractTrees
100-
const abstracttrees_children_lazy = lazy_interface(AbstractTrees.children)
97+
const abstracttrees_children_lazy = lazy_style(AbstractTrees.children)
10198
function abstracttrees_children_lazy(a)
10299
if !iscall(a)
103100
return ()
@@ -106,7 +103,7 @@ function abstracttrees_children_lazy(a)
106103
end
107104
end
108105
using AbstractTrees: nodevalue
109-
const nodevalue_lazy = lazy_interface(nodevalue)
106+
const nodevalue_lazy = lazy_style(nodevalue)
110107
function nodevalue_lazy(a)
111108
if !iscall(a)
112109
return unwrap(a)
@@ -115,11 +112,11 @@ function nodevalue_lazy(a)
115112
end
116113
end
117114
using Base.Broadcast: materialize
118-
const materialize_lazy = lazy_interface(materialize)
115+
const materialize_lazy = lazy_style(materialize)
119116
materialize_lazy(a) = argwalk(unwrap, a)
120-
const copy_lazy = lazy_interface(copy)
117+
const copy_lazy = lazy_style(copy)
121118
copy_lazy(a) = materialize(a)
122-
const equals_lazy = lazy_interface(==)
119+
const equals_lazy = lazy_style(==)
123120
function equals_lazy(a1, a2)
124121
u1, u2 = unwrap.((a1, a2))
125122
if !iscall(u1) && !iscall(u2)
@@ -130,7 +127,7 @@ function equals_lazy(a1, a2)
130127
return false
131128
end
132129
end
133-
const isequal_lazy = lazy_interface(isequal)
130+
const isequal_lazy = lazy_style(isequal)
134131
function isequal_lazy(a1, a2)
135132
u1, u2 = unwrap.((a1, a2))
136133
if !iscall(u1) && !iscall(u2)
@@ -141,13 +138,13 @@ function isequal_lazy(a1, a2)
141138
return false
142139
end
143140
end
144-
const hash_lazy = lazy_interface(hash)
141+
const hash_lazy = lazy_style(hash)
145142
function hash_lazy(a, h::UInt64)
146143
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
147144
# Use `_hash`, which defines a custom hash for NamedDimsArray.
148145
return _hash(unwrap(a), h)
149146
end
150-
const map_arguments_lazy = lazy_interface(map_arguments)
147+
const map_arguments_lazy = lazy_style(map_arguments)
151148
function map_arguments_lazy(f, a)
152149
u = unwrap(a)
153150
if !iscall(u)
@@ -159,21 +156,21 @@ function map_arguments_lazy(f, a)
159156
end
160157
end
161158
function substitute end
162-
const substitute_lazy = lazy_interface(substitute)
159+
const substitute_lazy = lazy_style(substitute)
163160
function substitute_lazy(a, substitutions::AbstractDict)
164161
haskey(substitutions, a) && return substitutions[a]
165162
!iscall(a) && return a
166163
return map_arguments(arg -> substitute(arg, substitutions), a)
167164
end
168165
substitute_lazy(a, substitutions) = substitute(a, Dict(substitutions))
169166
using AbstractTrees: printnode
170-
const printnode_lazy = lazy_interface(printnode)
167+
const printnode_lazy = lazy_style(printnode)
171168
function printnode_lazy(io, a)
172169
# Use `printnode_nameddims` to avoid type piracy,
173170
# since it overloads on `AbstractNamedDimsArray`.
174171
return printnode_nameddims(io, unwrap(a))
175172
end
176-
const show_lazy = lazy_interface(show)
173+
const show_lazy = lazy_style(show)
177174
function show_lazy(io::IO, a)
178175
if !iscall(a)
179176
return show(io, unwrap(a))
@@ -187,12 +184,12 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
187184
!iscall(a) ? show(io, mime, unwrap(a)) : show(io, a)
188185
return nothing
189186
end
190-
const add_lazy = lazy_interface(+)
187+
const add_lazy = lazy_style(+)
191188
add_lazy(a1, a2) = error("Not implemented.")
192-
const sub_lazy = lazy_interface(-)
189+
const sub_lazy = lazy_style(-)
193190
sub_lazy(a) = error("Not implemented.")
194191
sub_lazy(a1, a2) = error("Not implemented.")
195-
const mul_lazy = lazy_interface(*)
192+
const mul_lazy = lazy_style(*)
196193
function mul_lazy(a)
197194
u = unwrap(a)
198195
if !iscall(u)
@@ -216,7 +213,7 @@ mul_lazy(a1::Number, a2::Number) = a1 * a2
216213
div_lazy(a1, a2::Number) = error("Not implemented.")
217214

218215
# NamedDimsArrays.jl interface.
219-
const inds_lazy = lazy_interface(inds)
216+
const inds_lazy = lazy_style(inds)
220217
function inds_lazy(a)
221218
u = unwrap(a)
222219
if !iscall(u)
@@ -227,11 +224,11 @@ function inds_lazy(a)
227224
return error("Variant not supported.")
228225
end
229226
end
230-
const dename_lazy = lazy_interface(dename)
231-
function dename_lazy(a)
227+
const denamed_lazy = lazy_style(denamed)
228+
function denamed_lazy(a)
232229
u = unwrap(a)
233230
if !iscall(u)
234-
return dename(u)
231+
return denamed(u)
235232
else
236233
return error("Variant not supported.")
237234
end

src/LazyNamedDimsArrays/lazynameddimsarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a)
2424
lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a)
2525

2626
NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a)
27-
NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a)
27+
NamedDimsArrays.denamed(a::LazyNamedDimsArray) = denamed_lazy(a)
2828

2929
# Broadcasting
3030
function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray})

src/LazyNamedDimsArrays/nameddimsarraysextensions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
using NamedDimsArrays: NamedDimsArray, dename, inds
1+
using NamedDimsArrays: NamedDimsArray, denamed, inds
22
# Defined to avoid type piracy.
33
# TODO: Define a proper hash function
44
# in NamedDimsArrays.jl, maybe one that is
55
# independent of the order of dimensions.
66
function _hash(a::NamedDimsArray, h::UInt64)
77
h = hash(:NamedDimsArray, h)
8-
h = hash(dename(a), h)
8+
h = hash(denamed(a), h)
99
for i in inds(a)
1010
h = hash(i, h)
1111
end

src/LazyNamedDimsArrays/symbolicarray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ end
2828
function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N}
2929
return error("Indexing into SymbolicArray not supported.")
3030
end
31-
using DerivableInterfaces: DerivableInterfaces
32-
DerivableInterfaces.permuteddims(a::SymbolicArray, p) = permutedims(a, p)
31+
using FunctionImplementations: FunctionImplementations
32+
FunctionImplementations.permuteddims(a::SymbolicArray, p) = permutedims(a, p)
3333
function Base.permutedims(a::SymbolicArray, p)
3434
@assert ndims(a) == length(p) && isperm(p)
3535
return SymbolicArray(symname(a), ntuple(i -> axes(a)[p[i]], ndims(a)))

src/LazyNamedDimsArrays/symbolicnameddimsarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1-
using NamedDimsArrays: NamedDimsArray, dename, inds, nameddims
1+
using NamedDimsArrays: NamedDimsArray, denamed, inds, nameddims
22

33
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
44
NamedDimsArray{T, N, Parent, DimNames}
55
function symnameddims(name, dims)
6-
return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims))
6+
return lazy(nameddims(SymbolicArray(name, denamed.(dims)), dims))
77
end
88
function symnameddims(name, ndarray::AbstractNamedDimsArray)
99
return symnameddims(name, Tuple(inds(ndarray)))
1010
end
1111
symnameddims(name) = symnameddims(name, ())
1212
using AbstractTrees: AbstractTrees
1313
function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
14-
print(io, symname(dename(a)))
14+
print(io, symname(denamed(a)))
1515
if ndims(a) > 0
1616
print(io, "[", join(dimnames(a), ","), "]")
1717
end
1818
return nothing
1919
end
2020
printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a)
2121
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
22-
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
22+
return issetequal(inds(a), inds(b)) && denamed(a) == denamed(b)
2323
end
2424
Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b)
2525
Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b

src/TensorNetworkGenerators/ising_network.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using DiagonalArrays: DiagonalArray
22
using Graphs: degree, dst, edges, src
33
using ..ITensorNetworksNext: @preserve_graph
44
using LinearAlgebra: Diagonal, eigen
5-
using NamedDimsArrays: apply, dename, inds, operator, randname
5+
using NamedDimsArrays: apply, denamed, inds, operator, randname
66
using NamedGraphs.GraphsExtensions: vertextype
77

88
function sqrt_ising_bond(β; J = one(β), h = zero(β), deg1::Integer, deg2::Integer)
@@ -33,7 +33,7 @@ function ising_network(
3333
(e) = get(() -> l̃[reverse(e)], l̃, e)
3434
tn = delta_network(f̃, elt, g)
3535
for v in sz_vertices
36-
a = DiagonalArray(elt[1, -1], dename.(inds(tn[v])))
36+
a = DiagonalArray(elt[1, -1], denamed.(inds(tn[v])))
3737
tn[v] = a[inds(tn[v])...]
3838
end
3939
for e in edges(tn)

0 commit comments

Comments
 (0)