Skip to content

Commit 9e9fd3e

Browse files
authored
Alternative blocked permutation interface for contract (#120)
1 parent 2efd54e commit 9e9fd3e

21 files changed

Lines changed: 484 additions & 245 deletions

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.6.14"
4+
version = "0.7.0"
55

66
[deps]
7-
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
8-
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
97
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
108
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
119
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -15,17 +13,18 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1513
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1614

1715
[weakdeps]
16+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1817
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1918
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2019
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2120

2221
[extensions]
22+
TensorAlgebraBlockArraysExt = "BlockArrays"
2323
TensorAlgebraGPUArraysCoreExt = "GPUArraysCore"
2424
TensorAlgebraMooncakeExt = "Mooncake"
2525
TensorAlgebraTensorOperationsExt = "TensorOperations"
2626

2727
[compat]
28-
ArrayLayouts = "1.10.4"
2928
BlockArrays = "1.7.2"
3029
EllipsisNotation = "1.8"
3130
FunctionImplementations = "0.3.1, 0.4"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ TensorAlgebra = {path = ".."}
99
[compat]
1010
Documenter = "1.8.1"
1111
Literate = "2.20.1"
12-
TensorAlgebra = "0.6"
12+
TensorAlgebra = "0.7"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
55
TensorAlgebra = {path = ".."}
66

77
[compat]
8-
TensorAlgebra = "0.6"
8+
TensorAlgebra = "0.7"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module TensorAlgebraBlockArraysExt
2+
3+
include("blockarrays.jl")
4+
include("blockedtuple.jl")
5+
6+
end

src/blockarrays.jl renamed to ext/TensorAlgebraBlockArraysExt/blockarrays.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, blockedrange,
2-
eachblockaxes1, mortar
1+
using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, Block, BlockedArray,
2+
blockedrange, blocklength, blocks, eachblockaxes1, mortar
3+
using TensorAlgebra: TensorAlgebra, AbstractBlockTuple, BlockedTuple, FusionStyle,
4+
ReshapeFusion, matricize, matricize_axes, tensor_product_axis, unmatricize
35

46
struct BlockReshapeFusion <: FusionStyle end
5-
FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion()
7+
TensorAlgebra.FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion()
68

7-
function trivial_axis(
9+
function TensorAlgebra.trivial_axis(
810
style::BlockReshapeFusion, side::Val{:codomain}, a::AbstractArray,
911
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
1012
axes_domain::Tuple{Vararg{AbstractUnitRange}},
@@ -16,7 +18,7 @@ function mortar_axis(axs)
1618
throw(ArgumentError("Only one-based axes are supported"))
1719
return blockedrange(length.(axs))
1820
end
19-
function tensor_product_axis(
21+
function TensorAlgebra.tensor_product_axis(
2022
style::BlockReshapeFusion, side::Val{:codomain},
2123
r1::AbstractUnitRange, r2::AbstractUnitRange,
2224
)
@@ -26,7 +28,9 @@ function tensor_product_axis(
2628
blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs))
2729
return mortar_axis(blockaxs)
2830
end
29-
function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::Val)
31+
function TensorAlgebra.matricize(
32+
style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::Val
33+
)
3034
ax = matricize_axes(style, a, ndims_codomain)
3135
reshaped_blocks_a = reshape(blocks(a), blocklength.(ax))
3236
bs = map(reshaped_blocks_a) do b
@@ -35,7 +39,7 @@ function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::
3539
return mortar(bs, ax)
3640
end
3741
using BlockArrays: blocklengths
38-
function unmatricize(
42+
function TensorAlgebra.unmatricize(
3943
::BlockReshapeFusion, m::AbstractMatrix,
4044
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
4145
axes_domain::Tuple{Vararg{AbstractUnitRange}},
@@ -54,11 +58,11 @@ function unmatricize(
5458
return mortar(bs, ax)
5559
end
5660

57-
FusionStyle(::Type{<:BlockedArray}) = ReshapeFusion()
61+
TensorAlgebra.FusionStyle(::Type{<:BlockedArray}) = ReshapeFusion()
5862
unblock(a::BlockedArray) = a.blocks
5963
unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...]
6064
unblock(a::AbstractArray) = a
61-
function matricize(::ReshapeFusion, a::BlockedArray, ndims_codomain::Val)
65+
function TensorAlgebra.matricize(::ReshapeFusion, a::BlockedArray, ndims_codomain::Val)
6266
return matricize(ReshapeFusion(), unblock(a), ndims_codomain)
6367
end
6468
function unmatricize_blocked(
@@ -72,21 +76,21 @@ function unmatricize_blocked(
7276
)
7377
return BlockedArray(a, (axes_codomain..., axes_domain...))
7478
end
75-
function unmatricize(
79+
function TensorAlgebra.unmatricize(
7680
style::ReshapeFusion, m::AbstractMatrix,
7781
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
7882
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
7983
)
8084
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
8185
end
82-
function unmatricize(
86+
function TensorAlgebra.unmatricize(
8387
style::ReshapeFusion, m::AbstractMatrix,
8488
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
8589
axes_domain::Tuple{Vararg{AbstractBlockedUnitRange}},
8690
)
8791
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
8892
end
89-
function unmatricize(
93+
function TensorAlgebra.unmatricize(
9094
style::ReshapeFusion, m::AbstractMatrix,
9195
axes_codomain::Tuple{Vararg{AbstractBlockedUnitRange}},
9296
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import BlockArrays as BA
2+
import TensorAlgebra as TA
3+
4+
BA.blockfirsts(bt::TA.AbstractBlockTuple) = TA.blockfirsts(bt)
5+
BA.blocklasts(bt::TA.AbstractBlockTuple) = TA.blocklasts(bt)
6+
BA.blocklength(bt::TA.AbstractBlockTuple) = TA.blocklength(bt)
7+
BA.blocklengths(bt::TA.AbstractBlockTuple) = TA.blocklengths(bt)
8+
BA.blocklengths(type::Type{<:TA.AbstractBlockTuple}) = TA.blocklengths(type)
9+
BA.blocks(bt::TA.AbstractBlockTuple) = TA.blocks(bt)
10+
11+
TA.Block(I::BA.Block) = TA.Block(I.n)
12+
TA.BlockRange(I::BA.BlockRange) = TA.BlockRange(I.indices)
13+
TA.BlockIndexRange(I::BA.BlockIndexRange) = TA.BlockIndexRange(TA.Block(I.block), I.indices)
14+
Base.:(==)(I::BA.Block, J::TA.Block) = I.n == J.n
15+
Base.:(==)(I::TA.Block, J::BA.Block) = I.n == J.n
16+
Base.getindex(bt::TA.BlockedTuple, I::BA.Block) = bt[TA.Block(I)]
17+
Base.getindex(bt::TA.AbstractBlockTuple, I::BA.BlockIndexRange) = bt[TA.BlockIndexRange(I)]
18+
Base.getindex(bt::TA.AbstractBlockTuple, I::BA.BlockRange{1}) = bt[TA.BlockRange(I)]
19+
20+
BA.blocklasts(r::TA.BlockedOneTo) = TA.blocklasts(r)

ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ Mooncake.tangent_type(::Type{<:AbstractBlockPermutation}) = Mooncake.NoTangent
99
Mooncake.tangent_type(::Type{<:ContractAlgorithm}) = Mooncake.NoTangent
1010

1111
@zero_derivative DefaultCtx Tuple{
12-
typeof(allocate_output), typeof(contract), Any, Any, Any, Any, Any,
12+
typeof(allocate_output), typeof(contract), Any, Any, Any, Any, Any, Any, Any, Any,
1313
}
1414
@zero_derivative DefaultCtx Tuple{typeof(biperm), Any, Any}
1515
@zero_derivative DefaultCtx Tuple{typeof(blockedperms), typeof(contract), Any, Any, Any}
16-
@zero_derivative DefaultCtx Tuple{typeof(check_input), typeof(contract), Any, Any, Any, Any}
1716
@zero_derivative DefaultCtx Tuple{
18-
typeof(check_input), typeof(contract!), Any, Any, Any, Any, Any, Any,
17+
typeof(check_input), typeof(contract), Any, Any, Any, Any, Any, Any,
18+
}
19+
@zero_derivative DefaultCtx Tuple{
20+
typeof(check_input), typeof(contract!), Any, Any, Any, Any, Any, Any, Any, Any, Any,
1921
}
2022
@zero_derivative DefaultCtx Tuple{typeof(contract_labels), Any, Any}
2123
@zero_derivative DefaultCtx Tuple{typeof(contract_labels), Any, Any, Any, Any}

0 commit comments

Comments
 (0)