@@ -30,8 +30,13 @@ lazy_function(::typeof(*)) = *ₗ
3030lazy_function (:: typeof (/ )) = / ₗ
3131lazy_function (:: typeof (\ )) = \ ₗ
3232lazy_function (:: typeof (conj)) = conjed
33+ lazy_function (:: typeof (identity)) = identity
34+ lazy_function (f:: Base.Fix1{typeof(*), <:Number} ) = Base. Fix1 (* ₗ, f. x)
35+ lazy_function (f:: Base.Fix2{typeof(*), <:Number} ) = Base. Fix2 (* ₗ, f. x)
36+ lazy_function (f:: Base.Fix2{typeof(/), <:Number} ) = Base. Fix2 (/ ₗ, f. x)
3337
3438broadcast_is_linear (f, args... ) = false
39+ broadcast_is_linear (:: typeof (identity), :: Base.AbstractArrayOrBroadcasted ) = true
3540broadcast_is_linear (:: typeof (+ ), :: Base.AbstractArrayOrBroadcasted... ) = true
3641broadcast_is_linear (:: typeof (- ), :: Base.AbstractArrayOrBroadcasted ) = true
3742function broadcast_is_linear (
@@ -50,13 +55,41 @@ function broadcast_is_linear(
5055end
5156broadcast_is_linear (:: typeof (* ), :: Number , :: Number ) = true
5257broadcast_is_linear (:: typeof (conj), :: Base.AbstractArrayOrBroadcasted ) = true
58+ function broadcast_is_linear (
59+ :: Base.Fix1{typeof(*), <:Number} , :: Base.AbstractArrayOrBroadcasted
60+ )
61+ return true
62+ end
63+ function broadcast_is_linear (
64+ :: Base.Fix2{typeof(*), <:Number} , :: Base.AbstractArrayOrBroadcasted
65+ )
66+ return true
67+ end
68+ function broadcast_is_linear (
69+ :: Base.Fix2{typeof(/), <:Number} , :: Base.AbstractArrayOrBroadcasted
70+ )
71+ return true
72+ end
5373is_linear (x) = true
5474function is_linear (bc:: BC.Broadcasted )
5575 return broadcast_is_linear (bc. f, bc. args... ) && all (is_linear, bc. args)
5676end
5777
5878to_linear (x) = x
5979to_linear (bc:: BC.Broadcasted ) = lazy_function (bc. f)(to_linear .(bc. args)... )
80+ function broadcast_error (style, f)
81+ return throw (
82+ ArgumentError (
83+ " Only linear broadcast operations are supported for `$style `, got `$f `."
84+ )
85+ )
86+ end
87+ function broadcasted_linear (style:: BC.BroadcastStyle , f, args... )
88+ bc = BC. Broadcasted (style, f, args)
89+ is_linear (bc) || broadcast_error (style, f)
90+ return to_linear (bc)
91+ end
92+ broadcasted_linear (f, args... ) = broadcasted_linear (BC. combine_styles (args... ), f, args... )
6093# TODO : Use `Broadcast.broadcastable` interface for this?
6194to_broadcasted (x) = x
6295function to_broadcasted (a:: AbstractArray )
@@ -136,6 +169,7 @@ similar_scaled(a::AbstractArray) = similar(unscaled(a))
136169similar_scaled (a:: AbstractArray , elt:: Type ) = similar (unscaled (a), elt)
137170similar_scaled (a:: AbstractArray , ax) = similar (unscaled (a), ax)
138171similar_scaled (a:: AbstractArray , elt:: Type , ax) = similar (unscaled (a), elt, ax)
172+ getindex_scaled (a:: AbstractArray , I... ) = coeff (a) * getindex (unscaled (a), I... )
139173copyto!_scaled (dest:: AbstractArray , src:: AbstractArray ) = add! (dest, src, true , false )
140174show_scaled (io:: IO , a:: AbstractArray ) = show_lazy (io, a)
141175show_scaled (io:: IO , mime:: MIME"text/plain" , a:: AbstractArray ) = show_lazy (io, mime, a)
@@ -227,6 +261,9 @@ macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray)
227261 function Base. similar (a:: $ScaledArray , elt:: Type , ax:: Dims )
228262 return $ TensorAlgebra. similar_scaled (a, elt, ax)
229263 end
264+ Base. @propagate_inbounds function Base. getindex (a:: $ScaledArray , I... )
265+ return $ TensorAlgebra. getindex_scaled (a, I... )
266+ end
230267 function Base. copyto! (dest:: $AbstractArray , src:: $ScaledArray )
231268 return $ TensorAlgebra. copyto!_scaled (dest, src)
232269 end
@@ -372,6 +409,7 @@ size_conj(a::AbstractArray) = size(conjed(a))
372409similar_conj (a:: AbstractArray , elt:: Type ) = similar (conjed (a), elt)
373410similar_conj (a:: AbstractArray , elt:: Type , ax) = similar (conjed (a), elt, ax)
374411similar_conj (a:: AbstractArray , ax) = similar (conjed (a), ax)
412+ getindex_conj (a:: AbstractArray , I... ) = conj (getindex (conjed (a), I... ))
375413copyto!_conj (dest:: AbstractArray , src:: AbstractArray ) = add! (dest, src, true , false )
376414show_conj (io:: IO , a:: AbstractArray ) = show_lazy (io, a)
377415show_conj (io:: IO , mime:: MIME"text/plain" , a:: AbstractArray ) = show_lazy (io, mime, a)
@@ -424,6 +462,9 @@ macro conjarray_base(ConjArray, AbstractArray = :AbstractArray)
424462 function Base. similar (a:: $ConjArray , elt:: Type , ax:: Dims )
425463 return $ TensorAlgebra. similar_conj (a, elt, ax)
426464 end
465+ Base. @propagate_inbounds function Base. getindex (a:: $ConjArray , I... )
466+ return $ TensorAlgebra. getindex_conj (a, I... )
467+ end
427468 function Base. copyto! (dest:: $AbstractArray , src:: $ConjArray )
428469 return $ TensorAlgebra. copyto!_conj (dest, src)
429470 end
@@ -525,6 +566,7 @@ similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a))
525566function similar_add (a:: AbstractArray , elt:: Type , ax)
526567 return similar (BC. Broadcasted (+ , addends (a)), elt, ax)
527568end
569+ getindex_add (a:: AbstractArray , I... ) = sum (addend -> getindex (addend, I... ), addends (a))
528570copyto!_add (dest:: AbstractArray , src:: AbstractArray ) = add! (dest, src, true , false )
529571show_add (io:: IO , a:: AbstractArray ) = show_lazy (io, a)
530572show_add (io:: IO , mime:: MIME"text/plain" , a:: AbstractArray ) = show_lazy (io, mime, a)
@@ -611,6 +653,9 @@ macro addarray_base(AddArray, AbstractArray = :AbstractArray)
611653 function Base. similar (a:: $AddArray , elt:: Type , ax)
612654 return $ TensorAlgebra. similar_add (a, elt, ax)
613655 end
656+ Base. @propagate_inbounds function Base. getindex (a:: $AddArray , I... )
657+ return $ TensorAlgebra. getindex_add (a, I... )
658+ end
614659 function Base. copyto! (dest:: $AbstractArray , src:: $AddArray )
615660 return $ TensorAlgebra. copyto!_add (dest, src)
616661 end
@@ -741,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a))
741786# TODO : Make use of both arguments to determine the output, maybe
742787# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`?
743788similar_mul (a:: AbstractArray , elt:: Type , ax) = similar (last (factors (a)), elt, ax)
789+ function mul_getindex (a1:: AbstractMatrix , a2:: AbstractMatrix , i:: Int , j:: Int )
790+ return transpose (view (a1, i, :)) * view (a2, :, j)
791+ end
792+ function mul_getindex (a1:: AbstractMatrix , a2:: AbstractVector , i:: Int )
793+ return transpose (view (a1, i, :)) * a2
794+ end
795+ function mul_getindex (a1:: AbstractVector , a2:: AbstractMatrix , j:: Int )
796+ return transpose (a1) * view (a2, :, j)
797+ end
798+ function getindex_mul (a:: AbstractArray , i:: Int )
799+ I = Tuple (CartesianIndices (axes (a))[i])
800+ return getindex_mul (a, I... )
801+ end
802+ getindex_mul (a:: AbstractArray , I:: Vararg{Int} ) = mul_getindex (factors (a)... , I... )
744803copyto!_mul (dest:: AbstractArray , src:: AbstractArray ) = add! (dest, src, true , false )
745804show_mul (io:: IO , a:: AbstractArray ) = show_lazy (io, a)
746805show_mul (io:: IO , mime:: MIME"text/plain" , a:: AbstractArray ) = show_lazy (io, mime, a)
@@ -798,6 +857,11 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray)
798857 )
799858end
800859
860+ function copy_permuteddims (a:: PermutedDimsArray{<:Any, 2, perm} ) where {perm}
861+ perm == (1 , 2 ) && return copy (parent (a))
862+ return copy (transpose (parent (a)))
863+ end
864+
801865macro mularray_base (MulArray, AbstractArray = :AbstractArray )
802866 return esc (
803867 quote
@@ -819,6 +883,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray)
819883 function Base. similar (a:: $MulArray , elt:: Type , ax:: Dims )
820884 return $ TensorAlgebra. similar_mul (a, elt, ax)
821885 end
886+ Base. @propagate_inbounds function Base. getindex (a:: $MulArray , I... )
887+ return $ TensorAlgebra. getindex_mul (a, I... )
888+ end
822889 function Base. copyto! (dest:: $AbstractArray , src:: $MulArray )
823890 return $ TensorAlgebra. copyto!_mul (dest, src)
824891 end
@@ -881,6 +948,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray)
881948 $ TensorAlgebra. iscall (a:: $MulArray ) = $ TensorAlgebra. iscall_mul (a)
882949 $ TensorAlgebra. operation (a:: $MulArray ) = $ TensorAlgebra. operation_mul (a)
883950 $ TensorAlgebra. arguments (a:: $MulArray ) = $ TensorAlgebra. arguments_mul (a)
951+ function Base. copy (a:: PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray} )
952+ return $ TensorAlgebra. copy_permuteddims (a)
953+ end
884954 end
885955 )
886956end
0 commit comments