@@ -4,13 +4,21 @@ using BlockArrays: AbstractBlockArray, Block, blocklength, blocks, eachblockaxes
44using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
55 BlockUnitRange, blockrange, blocksparse
66using SparseArraysBase: eachstoredindex
7- using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, matricize, matricize_axes,
7+ using TensorAlgebra: TensorAlgebra, BlockedTuple, matricize, matricize_axes,
88 tensor_product_axis, unmatricize
99
10- const BlockReshapeFusion = typeof (FusionStyle (AbstractBlockArray))
10+ # TODO : Ideally we would use:
11+ # ```julia
12+ # const BlockReshapeFusion = typeof(TensorAlgebra.FusionStyle(AbstractBlockArray))
13+ # ```
14+ # but that doesn't seem to work, i.e. it sometimes returns `ReshapeFusion`. Maybe it is
15+ # a world age issue, though note that `@invokelatest` doesn't seem to fix it.
16+ # For now we just use `Base.get_extension`.
17+ const BlockReshapeFusion =
18+ Base. get_extension (TensorAlgebra, :TensorAlgebraBlockArraysExt ). BlockReshapeFusion
1119
1220function TensorAlgebra. tensor_product_axis (
13- style:: BlockReshapeFusion , side:: Val{:codomain} , r1:: BlockUnitRange , r2:: BlockUnitRange
21+ style:: BlockReshapeFusion , side:: Val{:codomain} , r1:: BlockUnitRange , r2:: BlockUnitRange ,
1422 )
1523 return tensor_product_blockrange (style, side, r1, r2)
1624end
0 commit comments