1- using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, blockedrange,
2- eachblockaxes1, mortar
1+ using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, Block, BlockedArray,
2+ blockedrange, blocklength, blocks, eachblockaxes1, mortar
3+ using TensorAlgebra: TensorAlgebra, AbstractBlockTuple, BlockedTuple, FusionStyle,
4+ ReshapeFusion, matricize, matricize_axes, tensor_product_axis, unmatricize
35
46struct BlockReshapeFusion <: FusionStyle end
5- FusionStyle (:: Type{<:AbstractBlockArray} ) = BlockReshapeFusion ()
7+ TensorAlgebra . FusionStyle (:: Type{<:AbstractBlockArray} ) = BlockReshapeFusion ()
68
7- function trivial_axis (
9+ function TensorAlgebra . trivial_axis (
810 style:: BlockReshapeFusion , side:: Val{:codomain} , a:: AbstractArray ,
911 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
1012 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
@@ -16,7 +18,7 @@ function mortar_axis(axs)
1618 throw (ArgumentError (" Only one-based axes are supported" ))
1719 return blockedrange (length .(axs))
1820end
19- function tensor_product_axis (
21+ function TensorAlgebra . tensor_product_axis (
2022 style:: BlockReshapeFusion , side:: Val{:codomain} ,
2123 r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
2224 )
@@ -26,7 +28,9 @@ function tensor_product_axis(
2628 blockaxs = vec (map (splat (tensor_product_axis), blockaxpairs))
2729 return mortar_axis (blockaxs)
2830end
29- function matricize (style:: BlockReshapeFusion , a:: AbstractArray , ndims_codomain:: Val )
31+ function TensorAlgebra. matricize (
32+ style:: BlockReshapeFusion , a:: AbstractArray , ndims_codomain:: Val
33+ )
3034 ax = matricize_axes (style, a, ndims_codomain)
3135 reshaped_blocks_a = reshape (blocks (a), blocklength .(ax))
3236 bs = map (reshaped_blocks_a) do b
@@ -35,7 +39,7 @@ function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::
3539 return mortar (bs, ax)
3640end
3741using BlockArrays: blocklengths
38- function unmatricize (
42+ function TensorAlgebra . unmatricize (
3943 :: BlockReshapeFusion , m:: AbstractMatrix ,
4044 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
4145 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
@@ -54,11 +58,11 @@ function unmatricize(
5458 return mortar (bs, ax)
5559end
5660
57- FusionStyle (:: Type{<:BlockedArray} ) = ReshapeFusion ()
61+ TensorAlgebra . FusionStyle (:: Type{<:BlockedArray} ) = ReshapeFusion ()
5862unblock (a:: BlockedArray ) = a. blocks
5963unblock (a:: AbstractBlockArray ) = a[Base. OneTo .(size (a))... ]
6064unblock (a:: AbstractArray ) = a
61- function matricize (:: ReshapeFusion , a:: BlockedArray , ndims_codomain:: Val )
65+ function TensorAlgebra . matricize (:: ReshapeFusion , a:: BlockedArray , ndims_codomain:: Val )
6266 return matricize (ReshapeFusion (), unblock (a), ndims_codomain)
6367end
6468function unmatricize_blocked (
@@ -72,21 +76,21 @@ function unmatricize_blocked(
7276 )
7377 return BlockedArray (a, (axes_codomain... , axes_domain... ))
7478end
75- function unmatricize (
79+ function TensorAlgebra . unmatricize (
7680 style:: ReshapeFusion , m:: AbstractMatrix ,
7781 axes_codomain:: Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}} ,
7882 axes_domain:: Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}} ,
7983 )
8084 return unmatricize_blocked (style, m, axes_codomain, axes_domain)
8185end
82- function unmatricize (
86+ function TensorAlgebra . unmatricize (
8387 style:: ReshapeFusion , m:: AbstractMatrix ,
8488 axes_codomain:: Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}} ,
8589 axes_domain:: Tuple{Vararg{AbstractBlockedUnitRange}} ,
8690 )
8791 return unmatricize_blocked (style, m, axes_codomain, axes_domain)
8892end
89- function unmatricize (
93+ function TensorAlgebra . unmatricize (
9094 style:: ReshapeFusion , m:: AbstractMatrix ,
9195 axes_codomain:: Tuple{Vararg{AbstractBlockedUnitRange}} ,
9296 axes_domain:: Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}} ,
0 commit comments