Skip to content

Commit d8c4ed4

Browse files
committed
Fix tests
1 parent ae8f349 commit d8c4ed4

1 file changed

Lines changed: 8 additions & 13 deletions

File tree

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
module BlockSparseArraysTensorAlgebraExt
22

33
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4-
using TensorAlgebra:
5-
TensorAlgebra,
6-
BlockedTrivialPermutation,
7-
BlockedTuple,
8-
FusionStyle,
9-
ReshapeFusion,
10-
fuseaxes
4+
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes
115

126
struct BlockReshapeFusion <: FusionStyle end
137

@@ -20,12 +14,12 @@ using BlockSparseArrays: blocksparse
2014
using SparseArraysBase: eachstoredindex
2115
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
2216
function TensorAlgebra.matricize(
23-
::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
17+
::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val
2418
)
25-
ax = fuseaxes(axes(a), biperm)
19+
ax = fuseaxes(axes(a), length1, length2)
2620
reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax))
2721
key(I) = Block(Tuple(I))
28-
value(I) = matricize(reshaped_blocks_a[I], biperm)
22+
value(I) = matricize(reshaped_blocks_a[I], length1, length2)
2923
Is = eachstoredindex(reshaped_blocks_a)
3024
bs = if isempty(Is)
3125
# Catch empty case and make sure the type is constrained properly.
@@ -45,16 +39,17 @@ using BlockArrays: blocklengths
4539
function TensorAlgebra.unmatricize(
4640
::BlockReshapeFusion,
4741
m::AbstractMatrix,
48-
blocked_ax::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}},
42+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
43+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
4944
)
50-
ax = Tuple(blocked_ax)
45+
ax = (codomain_axes..., domain_axes...)
5146
reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))
5247
function f(I)
5348
block_axes_I = BlockedTuple(
5449
map(ntuple(identity, length(ax))) do i
5550
return Base.axes1(ax[i][Block(I[i])])
5651
end,
57-
blocklengths(blocked_ax),
52+
(length(codomain_axes), length(domain_axes)),
5853
)
5954
return unmatricize(reshaped_blocks_m[I], block_axes_I)
6055
end

0 commit comments

Comments
 (0)