diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..909733566 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -8,6 +8,7 @@ using Distributed using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle, @allowscalar using IrrationalConstants: logtwo, logten using LinearAlgebra +using LinearAlgebra: AdjOrTrans using LinearAlgebra.BLAS using Random using RealDot: realdot @@ -62,6 +63,7 @@ include("rulesets/LinearAlgebra/factorization.jl") include("rulesets/LinearAlgebra/uniformscaling.jl") include("rulesets/SparseArrays/sparsematrix.jl") +include("rulesets/SparseArrays/symmetric.jl") include("rulesets/Random/random.jl") diff --git a/src/rulesets/SparseArrays/symmetric.jl b/src/rulesets/SparseArrays/symmetric.jl new file mode 100644 index 000000000..dad92bb69 --- /dev/null +++ b/src/rulesets/SparseArrays/symmetric.jl @@ -0,0 +1,292 @@ +##### +##### Hermitian/Symmetric sparse matrices +##### + +const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}} +const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}} +const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}} + +const DenseMat{T} = Union{StridedMatrix{T}, AdjOrTrans{T, <:StridedVecOrMat{T}}} +const DenseVecOrMat{T} = Union{DenseMat{T}, StridedVector{T}} + +function unwrap(A) + if A isa Adjoint + B = parent(A) + + if B isa Transpose + return (parent(B), Val(:N), Val(:C)) + else + return (B, Val(:T), Val(:C)) + end + elseif A isa Transpose + B = parent(A) + + if B isa Adjoint + return (parent(B), Val(:N), Val(:C)) + else + return (B, Val(:T), Val(:N)) + end + else + return (A, Val(:N), Val(:N)) + end +end + +##### +##### selupd! +##### + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A Bᴴ + conj(α) B Aᴴ + β C +# +# The update is only applied to the structural nonzeros of C. +function selupd!(C::HermSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, B, adjoint(A), conj(α), 1) + return C +end + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A Bᴴ + α conj(B) Aᵀ + β C +# +# The update is only applied to the structural nonzeros of C. +function selupd!(C::SymSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, adjoint(transpose(B)), transpose(A), α, 1) + return C +end + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A B + β C +# +# The update is only applied to the structural nonzeros of C. +function selupd!(C::SparseMatrixCSC, uplo::Char, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + AP, tA, cA = unwrap(A) + BP, tB, cB = unwrap(B) + return selupd_impl!(C, uplo, AP, BP, α, β, tA, cA, tB, cB) +end + +function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractVector, B::AbstractVector, α, β, ::Val{tA}, ::Val{cA}, ::Val{tB}, ::Val{cB}) where {tA, cA, tB, cB} + @assert size(C, 1) == size(C, 2) == length(A) == length(B) + + @inbounds for j in axes(C, 2) + Bj = cB === :C ? conj(B[j]) : B[j] + + for p in nzrange(C, j) + i = rowvals(C)[p] + + if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) + Ai = cA === :C ? conj(A[i]) : A[i] + + if iszero(β) + nonzeros(C)[p] = α * Ai * Bj + else + nonzeros(C)[p] = β * nonzeros(C)[p] + α * Ai * Bj + end + end + end + end + + return C +end + +function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractMatrix, B::AbstractMatrix, α, β, tA::Val{TA}, cA::Val{CA}, tB::Val{TB}, cB::Val{CB}) where {TA, CA, TB, CB} + @assert size(C, 1) == size(C, 2) + + if TA === :N && TB === :N + @assert size(A, 1) == size(C, 1) + @assert size(B, 2) == size(C, 1) + @assert size(A, 2) == size(B, 1) + elseif TA === :N && TB !== :N + @assert size(A, 1) == size(C, 1) + @assert size(B, 1) == size(C, 1) + @assert size(A, 2) == size(B, 2) + elseif TA !== :N && TB === :N + @assert size(A, 2) == size(C, 1) + @assert size(B, 2) == size(C, 1) + @assert size(A, 1) == size(B, 1) + else + @assert size(A, 2) == size(C, 1) + @assert size(B, 1) == size(C, 1) + @assert size(A, 1) == size(B, 2) + end + + if TA === :N + rng = axes(A, 2) + else + rng = axes(A, 1) + end + + if iszero(β) + fill!(nonzeros(C), β) + else + rmul!(nonzeros(C), β) + end + + for k in rng + if TA === :N + Ak = view(A, :, k) + else + Ak = view(A, k, :) + end + + if TB === :N + Bk = view(B, k, :) + else + Bk = view(B, :, k) + end + + selupd_impl!(C, uplo, Ak, Bk, α, 1, tA, cA, tB, cB) + end + + return C +end + +##### +##### rrule implementations +##### + +function mul_rrule_impl(A::HermOrSymSparse, B::DenseVecOrMat, ΔC) + ΔB = A * ΔC + ΔA = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔA = similar(A) + selupd!(ΔA, ΔC, B, 1 / 2, 0) + ΔA + end + end + return ΔA, ΔB +end + +function mul_rrule_impl(A::DenseMat, B::HermSparse, ΔC) + ΔA = ΔC * B + ΔB = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔB = similar(B) + selupd!(ΔB, A', ΔC', 1 / 2, 0) + ΔB + end + end + return ΔA, ΔB +end + +function mul_rrule_impl(A::DenseMat, B::SymSparse, ΔC) + ΔA = ΔC * B + ΔB = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔB = similar(B) + selupd!(ΔB, transpose(ΔC), transpose(A), 1 / 2, 0) + ΔB + end + end + return ΔA, ΔB +end + +function dot_rrule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, Ax::StridedVector, Ay::StridedVector, Δz) + Δx = @thunk Δz * Ay + Δy = @thunk Δz * Ax + + ΔA = if Δz isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔA = similar(A) + selupd!(ΔA, x, y, Δz / 2, 0) + ΔA + end + end + + return Δx, ΔA, Δy +end + +##### +##### rrule helpers +##### + +function mul_rrule(A::HermOrSymSparse, B::DenseVecOrMat) + C = A * B + + function pullback(ΔC) + ΔA, ΔB = mul_rrule_impl(A, B, ΔC) + return NoTangent(), ΔA, ΔB + end + + return C, pullback ∘ unthunk +end + +function mul_rrule(A::DenseMat, B::HermOrSymSparse) + C = A * B + + function pullback(ΔC) + ΔA, ΔB = mul_rrule_impl(A, B, ΔC) + return NoTangent(), ΔA, ΔB + end + + return C, pullback ∘ unthunk +end + +function dot_rrule(x::StridedVector, A::HermOrSymSparse, y::StridedVector) + Ax = A * x + Ay = A * y + z = dot(x, Ay) + + function pullback(Δz) + Δx, ΔA, Δy = dot_rrule_impl(x, A, y, Ax, Ay, Δz) + return NoTangent(), Δx, ΔA, Δy + end + + return z, pullback ∘ unthunk +end + +##### +##### frule implementations +##### + +function mul_frule_impl(A, B, dA, dB) + return A * B, dA * B + A * dB +end + +function dot_frule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, dx, dA, dy) + return dot(x, A, y), dot(dx, A, y) + dot(x, A, dy) + dot(x, dA, y) +end + +##### +##### frule / rrule dispatches +##### + +for T in (HermSparse, SymSparse) + # A * X + @eval function ChainRulesCore.frule((_, dA, dX)::Tuple, ::typeof(*), A::$T, X::DenseVecOrMat) + return mul_frule_impl(A, X, dA, dX) + end + + @eval function ChainRulesCore.rrule(::typeof(*), A::$T, X::DenseVecOrMat) + return mul_rrule(A, X) + end + + # X * A + @eval function ChainRulesCore.frule((_, dX, dA)::Tuple, ::typeof(*), X::DenseMat, A::$T) + return mul_frule_impl(X, A, dX, dA) + end + + @eval function ChainRulesCore.rrule(::typeof(*), X::DenseMat, A::$T) + return mul_rrule(X, A) + end + + # dot(x, A, y) - vectors only, matching upstream ChainRules + @eval function ChainRulesCore.frule((_, dx, dA, dy)::Tuple, ::typeof(dot), x::StridedVector, A::$T, y::StridedVector) + return dot_frule_impl(x, A, y, dx, dA, dy) + end + + @eval function ChainRulesCore.rrule(::typeof(dot), x::StridedVector, A::$T, y::StridedVector) + return dot_rrule(x, A, y) + end +end diff --git a/test/rulesets/SparseArrays/symmetric.jl b/test/rulesets/SparseArrays/symmetric.jl new file mode 100644 index 000000000..e919e446b --- /dev/null +++ b/test/rulesets/SparseArrays/symmetric.jl @@ -0,0 +1,184 @@ +@testset "Hermitian/Symmetric sparse matrices" begin + n = 10 + k = 3 # number of columns for matrix multiplication + + # Helper to create a random sparse Hermitian/Symmetric matrix + function rand_hermsym_sparse(SymHerm, T, n, uplo; density=0.3) + A = sprand(T, n, n, density) + A = A + A' + return SymHerm(A, uplo) + end + + # Helper to create a random tangent with same sparsity pattern + function rand_tangent_sparse(A::Union{Hermitian, Symmetric}) + dA = similar(A) + rand!(nonzeros(parent(dA))) + return dA + end + + # Convert symmetric gradient to raw parent convention (double off-diagonals in stored triangle, zero non-stored) + function symgrad_to_parent(∂A) + P = copy(parent(∂A)) + + for j in axes(P, 2) + for p in SparseArrays.nzrange(P, j) + i = rowvals(P)[p] + + if (∂A.uplo == 'U' && i < j) || (∂A.uplo == 'L' && i > j) + nonzeros(P)[p] *= 2 + elseif (∂A.uplo == 'U' && i > j) || (∂A.uplo == 'L' && i < j) + nonzeros(P)[p] = 0 + end + end + end + + return P + end + + # Test rrule by pulling back through parent and comparing to FD on parent + function test_sparse_rrule(f, A, args...; fdm=_fdm, rtol=1e-6, atol=1e-9) + SymHerm = typeof(A).name.wrapper + uplo = Symbol(A.uplo) + + # Compute rrule + y, pb = rrule(f, A, args...) + dy = rand_tangent(y) + cotangents = pb(dy) + ∂A = unthunk(cotangents[2]) + + # Convert symmetric gradient to parent convention + ∂data = symgrad_to_parent(∂A) + + # FD on parent data + f_parent(data) = f(SymHerm(data, uplo), args...) + ∂data_fd = j′vp(fdm, f_parent, dy, parent(A))[1] + + @test ∂data ≈ ∂data_fd rtol=rtol atol=atol + end + + function test_sparse_rrule_right(f, X, A; fdm=_fdm, rtol=1e-6, atol=1e-9) + SymHerm = typeof(A).name.wrapper + uplo = Symbol(A.uplo) + + # Compute rrule + y, pb = rrule(f, X, A) + dy = rand_tangent(y) + cotangents = pb(dy) + ∂A = unthunk(cotangents[3]) + + # Convert symmetric gradient to parent convention + ∂data = symgrad_to_parent(∂A) + + # FD on parent data + f_parent(data) = f(X, SymHerm(data, uplo)) + ∂data_fd = j′vp(fdm, f_parent, dy, parent(A))[1] + + @test ∂data ≈ ∂data_fd rtol=rtol atol=atol + end + + function test_sparse_rrule_dot(A, x, y; fdm=_fdm, rtol=1e-6, atol=1e-9) + SymHerm = typeof(A).name.wrapper + uplo = Symbol(A.uplo) + + # Compute rrule + val, pb = rrule(dot, x, A, y) + dval = rand_tangent(val) + cotangents = pb(dval) + ∂A = unthunk(cotangents[3]) + + # Convert symmetric gradient to parent convention + ∂data = symgrad_to_parent(∂A) + + # FD on parent data + f_parent(data) = dot(x, SymHerm(data, uplo), y) + ∂data_fd = j′vp(fdm, f_parent, dval, parent(A))[1] + + @test ∂data ≈ ∂data_fd rtol=rtol atol=atol + end + + @testset "$(SymHerm){$T} * DenseVecOrMat, uplo=:$uplo" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + A = rand_hermsym_sparse(SymHerm, T, n, uplo) + dA = rand_tangent_sparse(A) + x = randn(T, n) + X = randn(T, n, k) + Xt = randn(T, k, n) # for transpose/adjoint tests + + @testset "A * x (vector)" begin + test_sparse_rrule(*, A, x) + test_frule(*, A ⊢ dA, x) + end + + @testset "A * X (matrix)" begin + test_sparse_rrule(*, A, X) + test_frule(*, A ⊢ dA, X) + end + + @testset "A * X' (adjoint matrix)" begin + test_sparse_rrule(*, A, Xt') + test_frule(*, A ⊢ dA, Xt') + end + + @testset "A * transpose(X) (transpose matrix)" begin + test_sparse_rrule(*, A, transpose(Xt)) + test_frule(*, A ⊢ dA, transpose(Xt)) + end + end + + @testset "DenseMat * $(SymHerm){$T}, uplo=:$uplo" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + A = rand_hermsym_sparse(SymHerm, T, n, uplo) + dA = rand_tangent_sparse(A) + X = randn(T, k, n) + Xt = randn(T, n, k) # for transpose/adjoint tests + x = randn(T, n) # for row vector tests + + @testset "X * A (matrix)" begin + test_sparse_rrule_right(*, X, A) + test_frule(*, X, A ⊢ dA) + end + + @testset "X' * A (adjoint matrix)" begin + test_sparse_rrule_right(*, Xt', A) + test_frule(*, Xt', A ⊢ dA) + end + + @testset "transpose(X) * A (transpose matrix)" begin + test_sparse_rrule_right(*, transpose(Xt), A) + test_frule(*, transpose(Xt), A ⊢ dA) + end + + @testset "x' * A (adjoint vector / row vector)" begin + test_sparse_rrule_right(*, x', A) + test_frule(*, x', A ⊢ dA) + end + + @testset "transpose(x) * A (transpose vector / row vector)" begin + test_sparse_rrule_right(*, transpose(x), A) + test_frule(*, transpose(x), A ⊢ dA) + end + end + + # dot(x, A, y) - vectors only, matching upstream ChainRules + @testset "dot(x, $(SymHerm){$T}, y), uplo=:$uplo" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + A = rand_hermsym_sparse(SymHerm, T, n, uplo) + dA = rand_tangent_sparse(A) + x = randn(T, n) + y = randn(T, n) + + @testset "dot(x, A, y)" begin + test_sparse_rrule_dot(A, x, y) + test_frule(dot, x, A ⊢ dA, y) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..228289e57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,6 +85,7 @@ end println() include_test("rulesets/SparseArrays/sparsematrix.jl") + include_test("rulesets/SparseArrays/symmetric.jl") println()