Skip to content

Commit 7b3bd07

Browse files
feat: add GPU-native kron support for Diagonal matrices (#690)
1 parent 88dcf9c commit 7b3bd07

2 files changed

Lines changed: 92 additions & 0 deletions

File tree

src/host/linalg.jl

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -952,3 +952,57 @@ for wrapa in trans_adj_wrappers, wrapb in trans_adj_wrappers
952952
return kron!(C, A, B)
953953
end
954954
end
955+
956+
# Kernel computing `C = kron(Diagonal(a), B)`, one work-item per entry of `C`.
# `C` is treated as a grid of blocks shaped like `B`: block `(p, q)` receives
# `a[p] * B` when `p == q` and zeros otherwise.
@kernel function kron_diag_dense_kernel!(C, @Const(a), @Const(B))
    row, col = @index(Global, NTuple)
    brows = size(B, 1)
    bcols = size(B, 2)
    # Split each global index into (block index, index inside the block).
    blockrow = fld1(row, brows)
    inrow = mod1(row, brows)
    blockcol = fld1(col, bcols)
    incol = mod1(col, bcols)
    @inbounds if blockrow == blockcol
        C[row, col] = a[blockrow] * B[inrow, incol]
    else
        C[row, col] = zero(eltype(C))
    end
end
966+
967+
# Kernel computing `C = kron(A, Diagonal(b))`, one work-item per entry of `C`.
# `C` is treated as a grid of `length(b)`-by-`length(b)` blocks: block `(p, q)`
# is `A[p, q] * Diagonal(b)`, so only in-block diagonal positions are non-zero.
@kernel function kron_dense_diag_kernel!(C, @Const(A), @Const(b))
    row, col = @index(Global, NTuple)
    blocklen = length(b)
    # Split each global index into (block index, index inside the block).
    blockrow = fld1(row, blocklen)
    inrow = mod1(row, blocklen)
    blockcol = fld1(col, blocklen)
    incol = mod1(col, blocklen)
    @inbounds if inrow == incol
        C[row, col] = A[blockrow, blockcol] * b[inrow]
    else
        C[row, col] = zero(eltype(C))
    end
end
976+
977+
# In-place Kronecker product `C = kron(A, B)` of a GPU-backed `Diagonal` `A`
# with a dense GPU matrix `B`.
#
# `C` must have size `(n * size(B, 1), n * size(B, 2))` where
# `n = length(A.diag)`; otherwise a `DimensionMismatch` is thrown.
# Every entry of `C` is overwritten. Returns `C`.
function LinearAlgebra.kron!(C::AbstractGPUMatrix, A::Diagonal{T1, <:AbstractGPUVector}, B::AbstractGPUMatrix{T2}) where {T1, T2}
    n = length(A.diag)
    expected = (n * size(B, 1), n * size(B, 2))
    if size(C) != expected
        throw(DimensionMismatch(
            "kron!: destination has size $(size(C)), but kron(A, B) has size $expected"
        ))
    end
    # Nothing to write for an empty result, so skip the kernel launch entirely.
    isempty(C) && return C
    backend = KernelAbstractions.get_backend(C)
    kron_diag_dense_kernel!(backend)(C, A.diag, B, ndrange = size(C))
    return C
end
983+
984+
# Kronecker product of a GPU-backed `Diagonal` `A` with a dense GPU matrix `B`.
# The result is allocated with `similar(B, ...)` using the promoted element
# type of both operands and then filled by `kron!`.
function LinearAlgebra.kron(A::Diagonal{T1, <:AbstractGPUVector}, B::AbstractGPUMatrix{T2}) where {T1, T2}
    n = length(A.diag)
    rows = n * size(B, 1)
    cols = n * size(B, 2)
    C = similar(B, promote_type(T1, T2), rows, cols)
    return kron!(C, A, B)
end
988+
989+
# In-place Kronecker product `C = kron(A, B)` of a dense GPU matrix `A` with a
# GPU-backed `Diagonal` `B`.
#
# `C` must have size `(size(A, 1) * n, size(A, 2) * n)` where
# `n = length(B.diag)`; otherwise a `DimensionMismatch` is thrown.
# Every entry of `C` is overwritten. Returns `C`.
function LinearAlgebra.kron!(C::AbstractGPUMatrix, A::AbstractGPUMatrix{T1}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2}
    n = length(B.diag)
    expected = (size(A, 1) * n, size(A, 2) * n)
    if size(C) != expected
        throw(DimensionMismatch(
            "kron!: destination has size $(size(C)), but kron(A, B) has size $expected"
        ))
    end
    # Nothing to write for an empty result, so skip the kernel launch entirely.
    isempty(C) && return C
    backend = KernelAbstractions.get_backend(C)
    kron_dense_diag_kernel!(backend)(C, A, B.diag, ndrange = size(C))
    return C
end
995+
996+
# Kronecker product of a dense GPU matrix `A` with a GPU-backed `Diagonal` `B`.
# The result is allocated with `similar(A, ...)` using the promoted element
# type of both operands and then filled by `kron!`.
function LinearAlgebra.kron(A::AbstractGPUMatrix{T1}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2}
    n = length(B.diag)
    rows = size(A, 1) * n
    cols = size(A, 2) * n
    C = similar(A, promote_type(T1, T2), rows, cols)
    return kron!(C, A, B)
end
1000+
1001+
# In-place Kronecker product of two GPU-backed `Diagonal` matrices into a
# `Diagonal` destination. Only the diagonals are involved: the work is
# delegated to `kron!` on the three underlying diagonal vectors. Returns `C`.
function LinearAlgebra.kron!(
        C::Diagonal{<:Any, <:AbstractGPUVector},
        A::Diagonal{T1, <:AbstractGPUVector},
        B::Diagonal{T2, <:AbstractGPUVector},
    ) where {T1, T2}
    dest = C.diag
    kron!(dest, A.diag, B.diag)
    return C
end
1005+
1006+
# Kronecker product of two GPU-backed `Diagonal` matrices. The result is a
# `Diagonal` wrapping the `kron` of the two diagonal vectors.
function LinearAlgebra.kron(
        A::Diagonal{T1, <:AbstractGPUVector},
        B::Diagonal{T2, <:AbstractGPUVector},
    ) where {T1, T2}
    d = kron(A.diag, B.diag)
    return Diagonal(d)
end

test/testsuite/linalg.jl

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -567,3 +567,41 @@ end
567567
end
568568
end
569569
end
570+
571+
# Checks the GPU `kron` / `kron!` methods involving `Diagonal` operands against
# the CPU reference results, for the Float32/Float64 element types in `eltypes`.
@testsuite "linalg/kron_diagonal" (AT, eltypes) -> begin
    for T in filter(T -> T == Float32 || T == Float64, eltypes)
        n, m = 16, 8
        a, b = rand(T, n), rand(T, m)

        # Diagonal*Diagonal
        R = kron(Diagonal(adapt(AT, a)), Diagonal(adapt(AT, b)))
        @test R isa Diagonal
        @test Array(R.diag) ≈ kron(a, b)

        # Diagonal*Dense
        B = rand(T, m, m)
        R2 = kron(Diagonal(adapt(AT, a)), adapt(AT, B))
        @test Array(R2) ≈ kron(Matrix(Diagonal(a)), B)

        # Dense*Diagonal
        A = rand(T, n, n)
        R3 = kron(adapt(AT, A), Diagonal(adapt(AT, b)))
        @test Array(R3) ≈ kron(A, Matrix(Diagonal(b)))

        # kron! Diagonal*Diagonal
        C1 = Diagonal(adapt(AT, zeros(T, n * m)))
        kron!(C1, Diagonal(adapt(AT, a)), Diagonal(adapt(AT, b)))
        @test C1 isa Diagonal
        @test Array(C1.diag) ≈ kron(a, b)

        # kron! Diagonal*Dense
        C2 = adapt(AT, zeros(T, n * m, n * m))
        kron!(C2, Diagonal(adapt(AT, a)), adapt(AT, B))
        @test Array(C2) ≈ kron(Matrix(Diagonal(a)), B)

        # kron! Dense*Diagonal
        C3 = adapt(AT, zeros(T, n * m, n * m))
        kron!(C3, adapt(AT, A), Diagonal(adapt(AT, b)))
        @test Array(C3) ≈ kron(A, Matrix(Diagonal(b)))
    end
end

0 commit comments

Comments
 (0)