Skip to content

Commit a9f68a9

Browse files
authored
broadcasted_linear, getindex on lazy arrays (#153)
1 parent 8db3991 commit a9f68a9

3 files changed

Lines changed: 100 additions & 4 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.7.19"
3+
version = "0.7.20"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

src/lazyarrays.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ lazy_function(::typeof(*)) = *ₗ
3030
lazy_function(::typeof(/)) = /
3131
lazy_function(::typeof(\)) = \
3232
lazy_function(::typeof(conj)) = conjed
33+
lazy_function(::typeof(identity)) = identity
34+
lazy_function(f::Base.Fix1{typeof(*), <:Number}) = Base.Fix1(*ₗ, f.x)
35+
lazy_function(f::Base.Fix2{typeof(*), <:Number}) = Base.Fix2(*ₗ, f.x)
36+
lazy_function(f::Base.Fix2{typeof(/), <:Number}) = Base.Fix2(/ₗ, f.x)
3337

3438
broadcast_is_linear(f, args...) = false
39+
broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted) = true
3540
broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...) = true
3641
broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted) = true
3742
function broadcast_is_linear(
@@ -50,13 +55,41 @@ function broadcast_is_linear(
5055
end
5156
broadcast_is_linear(::typeof(*), ::Number, ::Number) = true
5257
broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true
58+
function broadcast_is_linear(
59+
::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted
60+
)
61+
return true
62+
end
63+
function broadcast_is_linear(
64+
::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted
65+
)
66+
return true
67+
end
68+
function broadcast_is_linear(
69+
::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted
70+
)
71+
return true
72+
end
5373
is_linear(x) = true
5474
function is_linear(bc::BC.Broadcasted)
5575
return broadcast_is_linear(bc.f, bc.args...) && all(is_linear, bc.args)
5676
end
5777

5878
to_linear(x) = x
5979
to_linear(bc::BC.Broadcasted) = lazy_function(bc.f)(to_linear.(bc.args)...)
80+
function broadcast_error(style, f)
81+
return throw(
82+
ArgumentError(
83+
"Only linear broadcast operations are supported for `$style`, got `$f`."
84+
)
85+
)
86+
end
87+
function broadcasted_linear(style::BC.BroadcastStyle, f, args...)
88+
bc = BC.Broadcasted(style, f, args)
89+
is_linear(bc) || broadcast_error(style, f)
90+
return to_linear(bc)
91+
end
92+
broadcasted_linear(f, args...) = broadcasted_linear(BC.combine_styles(args...), f, args...)
6093
# TODO: Use `Broadcast.broadcastable` interface for this?
6194
to_broadcasted(x) = x
6295
function to_broadcasted(a::AbstractArray)
@@ -136,6 +169,7 @@ similar_scaled(a::AbstractArray) = similar(unscaled(a))
136169
similar_scaled(a::AbstractArray, elt::Type) = similar(unscaled(a), elt)
137170
similar_scaled(a::AbstractArray, ax) = similar(unscaled(a), ax)
138171
similar_scaled(a::AbstractArray, elt::Type, ax) = similar(unscaled(a), elt, ax)
172+
getindex_scaled(a::AbstractArray, I...) = coeff(a) * getindex(unscaled(a), I...)
139173
copyto!_scaled(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
140174
show_scaled(io::IO, a::AbstractArray) = show_lazy(io, a)
141175
show_scaled(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
@@ -227,6 +261,9 @@ macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray)
227261
function Base.similar(a::$ScaledArray, elt::Type, ax::Dims)
228262
return $TensorAlgebra.similar_scaled(a, elt, ax)
229263
end
264+
Base.@propagate_inbounds function Base.getindex(a::$ScaledArray, I...)
265+
return $TensorAlgebra.getindex_scaled(a, I...)
266+
end
230267
function Base.copyto!(dest::$AbstractArray, src::$ScaledArray)
231268
return $TensorAlgebra.copyto!_scaled(dest, src)
232269
end
@@ -372,6 +409,7 @@ size_conj(a::AbstractArray) = size(conjed(a))
372409
similar_conj(a::AbstractArray, elt::Type) = similar(conjed(a), elt)
373410
similar_conj(a::AbstractArray, elt::Type, ax) = similar(conjed(a), elt, ax)
374411
similar_conj(a::AbstractArray, ax) = similar(conjed(a), ax)
412+
getindex_conj(a::AbstractArray, I...) = conj(getindex(conjed(a), I...))
375413
copyto!_conj(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
376414
show_conj(io::IO, a::AbstractArray) = show_lazy(io, a)
377415
show_conj(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
@@ -424,6 +462,9 @@ macro conjarray_base(ConjArray, AbstractArray = :AbstractArray)
424462
function Base.similar(a::$ConjArray, elt::Type, ax::Dims)
425463
return $TensorAlgebra.similar_conj(a, elt, ax)
426464
end
465+
Base.@propagate_inbounds function Base.getindex(a::$ConjArray, I...)
466+
return $TensorAlgebra.getindex_conj(a, I...)
467+
end
427468
function Base.copyto!(dest::$AbstractArray, src::$ConjArray)
428469
return $TensorAlgebra.copyto!_conj(dest, src)
429470
end
@@ -525,6 +566,7 @@ similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a))
525566
function similar_add(a::AbstractArray, elt::Type, ax)
526567
return similar(BC.Broadcasted(+, addends(a)), elt, ax)
527568
end
569+
getindex_add(a::AbstractArray, I...) = sum(addend -> getindex(addend, I...), addends(a))
528570
copyto!_add(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
529571
show_add(io::IO, a::AbstractArray) = show_lazy(io, a)
530572
show_add(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
@@ -611,6 +653,9 @@ macro addarray_base(AddArray, AbstractArray = :AbstractArray)
611653
function Base.similar(a::$AddArray, elt::Type, ax)
612654
return $TensorAlgebra.similar_add(a, elt, ax)
613655
end
656+
Base.@propagate_inbounds function Base.getindex(a::$AddArray, I...)
657+
return $TensorAlgebra.getindex_add(a, I...)
658+
end
614659
function Base.copyto!(dest::$AbstractArray, src::$AddArray)
615660
return $TensorAlgebra.copyto!_add(dest, src)
616661
end
@@ -741,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a))
741786
# TODO: Make use of both arguments to determine the output, maybe
742787
# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`?
743788
similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax)
789+
function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int)
790+
return transpose(view(a1, i, :)) * view(a2, :, j)
791+
end
792+
function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int)
793+
return transpose(view(a1, i, :)) * a2
794+
end
795+
function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int)
796+
return transpose(a1) * view(a2, :, j)
797+
end
798+
function getindex_mul(a::AbstractArray, i::Int)
799+
I = Tuple(CartesianIndices(axes(a))[i])
800+
return getindex_mul(a, I...)
801+
end
802+
getindex_mul(a::AbstractArray, I::Vararg{Int}) = mul_getindex(factors(a)..., I...)
744803
copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
745804
show_mul(io::IO, a::AbstractArray) = show_lazy(io, a)
746805
show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
@@ -798,6 +857,11 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray)
798857
)
799858
end
800859

860+
function copy_permuteddims(a::PermutedDimsArray{<:Any, 2, perm}) where {perm}
861+
perm == (1, 2) && return copy(parent(a))
862+
return copy(transpose(parent(a)))
863+
end
864+
801865
macro mularray_base(MulArray, AbstractArray = :AbstractArray)
802866
return esc(
803867
quote
@@ -819,6 +883,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray)
819883
function Base.similar(a::$MulArray, elt::Type, ax::Dims)
820884
return $TensorAlgebra.similar_mul(a, elt, ax)
821885
end
886+
Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...)
887+
return $TensorAlgebra.getindex_mul(a, I...)
888+
end
822889
function Base.copyto!(dest::$AbstractArray, src::$MulArray)
823890
return $TensorAlgebra.copyto!_mul(dest, src)
824891
end
@@ -881,6 +948,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray)
881948
$TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a)
882949
$TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a)
883950
$TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a)
951+
function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray})
952+
return $TensorAlgebra.copy_permuteddims(a)
953+
end
884954
end
885955
)
886956
end

test/test_lazy.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import FunctionImplementations as FI
2-
using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, conjed
3-
using Test: @test, @test_broken, @testset
2+
using Base.Broadcast: Broadcast as BC
3+
using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, /ₗ, conjed
4+
using Test: @test, @test_broken, @test_throws, @testset
45

56
@testset "lazy arrays" begin
67
@testset "lazy array operations" begin
@@ -92,6 +93,31 @@ using Test: @test, @test_broken, @testset
9293

9394
x = FI.permuteddims(a *ₗ b, perm)
9495
@test x PermutedDimsArray(a *ₗ b, perm)
95-
@test_broken copy(x) permutedims(a * b, perm)
96+
@test copy(x) permutedims(a * b, perm)
97+
end
98+
@testset "linear broadcast lowering" begin
99+
a = randn(ComplexF64, 2, 2)
100+
style = BC.DefaultArrayStyle{2}()
101+
102+
@test TA.broadcasted_linear(identity, a) a
103+
@test TA.broadcasted_linear(Base.Fix1(*, 2), a) 2 *ₗ a
104+
@test TA.broadcasted_linear(Base.Fix2(*, 2), a) a *2
105+
@test TA.broadcasted_linear(Base.Fix2(/, 2), a) a /2
106+
@test TA.broadcasted_linear(style, identity, a) a
107+
@test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) 2 *ₗ a
108+
@test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) a *2
109+
@test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) a /2
110+
@test TA.broadcasted_linear(style, conj, a) conjed(a)
111+
@test_throws ArgumentError TA.broadcasted_linear(style, exp, a)
112+
end
113+
@testset "scalar getindex" begin
114+
a = randn(ComplexF64, 2, 2)
115+
b = randn(ComplexF64, 2, 2)
116+
117+
@test (2 *ₗ a)[1, 2] == 2 * a[1, 2]
118+
@test conjed(a)[2, 1] == conj(a[2, 1])
119+
@test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2]
120+
@test (a *ₗ b)[1, 2] (a * b)[1, 2]
121+
@test (a *ₗ b)[3] (a * b)[3]
96122
end
97123
end

0 commit comments

Comments
 (0)