Skip to content

Commit 74c77e6

Browse files
committed
Fix tests
1 parent e2e209e commit 74c77e6

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@ using BlockArrays: AbstractBlockArray, Block, blocklength, blocks, eachblockaxes
44
using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
55
BlockUnitRange, blockrange, blocksparse
66
using SparseArraysBase: eachstoredindex
7-
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, matricize, matricize_axes,
7+
using TensorAlgebra: TensorAlgebra, BlockedTuple, matricize, matricize_axes,
88
tensor_product_axis, unmatricize
99

10-
const BlockReshapeFusion = typeof(FusionStyle(AbstractBlockArray))
10+
# TODO: Ideally we would use:
11+
# ```julia
12+
# const BlockReshapeFusion = typeof(TensorAlgebra.FusionStyle(AbstractBlockArray))
13+
# ```
14+
# but that doesn't seem to work, i.e. it sometimes returns `ReshapeFusion`. Maybe it is
15+
# a world age issue, though note that `@invokelatest` doesn't seem to fix it.
16+
# For now we just use `Base.get_extension`.
17+
const BlockReshapeFusion =
18+
Base.get_extension(TensorAlgebra, :TensorAlgebraBlockArraysExt).BlockReshapeFusion
1119

1220
function TensorAlgebra.tensor_product_axis(
13-
style::BlockReshapeFusion, side::Val{:codomain}, r1::BlockUnitRange, r2::BlockUnitRange
21+
style::BlockReshapeFusion, side::Val{:codomain}, r1::BlockUnitRange, r2::BlockUnitRange,
1422
)
1523
return tensor_product_blockrange(style, side, r1, r2)
1624
end

0 commit comments

Comments
 (0)