diff --git a/src/abstract.jl b/src/abstract.jl index 8c16694f..43f8415c 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -24,6 +24,8 @@ const LinearOperatorIndexType{I} = # import methods we overload import Base.eltype, Base.isreal, Base.size, Base.show +import Base.copy!, Base.copyto! +import Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Broadcasted, broadcastable, materialize! import LinearAlgebra.Symmetric, LinearAlgebra.issymmetric, LinearAlgebra.Hermitian, LinearAlgebra.ishermitian @@ -44,13 +46,13 @@ other operators, with matrices and with scalars. Operators may be transposed and conjugate-transposed using the usual Julia syntax. """ mutable struct LinearOperator{T, S, I <: Integer, F, Ft, Fct} <: AbstractLinearOperator{T} - const nrow::I - const ncol::I - const symmetric::Bool - const hermitian::Bool - const prod!::F - const tprod!::Ft - const ctprod!::Fct + nrow::I + ncol::I + symmetric::Bool + hermitian::Bool + prod!::F + tprod!::Ft + ctprod!::Fct nprod::I ntprod::I nctprod::I @@ -183,6 +185,55 @@ storage_type(op::Adjoint) = storage_type(parent(op)) storage_type(op::Transpose) = storage_type(parent(op)) storage_type(op::Diagonal) = typeof(parent(op)) +broadcastable(op::AbstractLinearOperator) = Ref(op) +BroadcastStyle(::Type{<:AbstractLinearOperator}) = DefaultArrayStyle{0}() + +function copyto!( + dest::LinearOperator{T, S}, + src::LinearOperator{T, S}, +) where {T, S} + dest.nrow = src.nrow + dest.ncol = src.ncol + dest.symmetric = src.symmetric + dest.hermitian = src.hermitian + dest.prod! = src.prod! + dest.tprod! = src.tprod! + dest.ctprod! = src.ctprod! + dest.nprod = src.nprod + dest.ntprod = src.ntprod + dest.nctprod = src.nctprod + dest.Mv = src.Mv + dest.Mtu = src.Mtu + return dest +end + +function copyto!(dest::LinearOperator, src::LinearOperator) + throw( + LinearOperatorException( + "cannot update a LinearOperator in-place: incompatible element types " * + "$(eltype(dest)) and $(eltype(src)), or storage types " * + "$(typeof(dest.Mv)) and $(typeof(src.Mv)). " * + "Use assignment (`dest = src`) instead.", + ), + ) +end + +copy!(dest::LinearOperator, src::LinearOperator) = copyto!(dest, src) + +function materialize!(dest::LinearOperator, bc::Broadcasted{DefaultArrayStyle{0}}) + (bc.f === identity && length(bc.args) == 1) || + throw( + LinearOperatorException( + "only broadcast assignment of a single LinearOperator is supported (e.g., `op .= new_op`).", + ), + ) + src_arg = bc.args[1] + src = src_arg isa Ref ? src_arg[] : src_arg + src isa LinearOperator || + throw(LinearOperatorException("right-hand side of `op .= ...` must be a LinearOperator")) + return copyto!(dest, src) +end + """ reset!(op) diff --git a/test/test_linop.jl b/test/test_linop.jl index 6e4ad36f..72600166 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -111,6 +111,31 @@ function test_linop() @test Matrix(Hermitian(op6)) == (A6 + adjoint(A6)) / 2 end + @testset "In-place operator update" begin + A = rand(5, 5) + B = rand(5, 5) + x = rand(5) + + opA = LinearOperator(A) + opB = LinearOperator(B) + + copyto!(opA, opB) + @test opA * x ≈ B * x + + C = rand(5, 5) + opC = LinearOperator(C) + copy!(opA, opC) + @test opA * x ≈ C * x + + opA .= opB + @test opA * x ≈ B * x + + @test_throws LinearOperatorException opA .+= opB + + op_complex = LinearOperator(rand(ComplexF64, 5, 5)) + @test_throws LinearOperatorException copyto!(opA, op_complex) + end + @testset "Constructor with specified structure" begin v = simple_vector(Float64, nrow) A = simple_matrix(ComplexF64, nrow, nrow)