Skip to content

Commit 1966bfc

Browse files
authored
Add ShiftedOperator for shifted linear operators (#394)
Introduces the ShiftedOperator type to represent operators of the form H + σI, supporting efficient multiplication and adjoint/transpose operations. Includes comprehensive tests for real symmetric, complex non-Hermitian, and mutable shift scenarios. The new operator is integrated into the main module.
1 parent e5a09d6 commit 1966bfc

4 files changed

Lines changed: 233 additions & 0 deletions

File tree

src/LinearOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include("DiagonalHessianApproximation.jl")
2222
include("linalg.jl")
2323
include("special-operators.jl")
2424
include("TimedOperators.jl")
25+
include("shifted_operators.jl")
2526

2627
# Utilities
2728
include("utilities.jl")

src/shifted_operators.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
export ShiftedOperator
2+
3+
"A data type to hold information relative to shifted operators."
4+
mutable struct ShiftedData{T, OpH}
5+
H::OpH
6+
σ::T
7+
function ShiftedData{T, OpH}(H::OpH, σ::T) where {T, OpH}
8+
size(H, 1) == size(H, 2) || throw(DimensionMismatch("Operator H must be square."))
9+
new{T, OpH}(H, σ)
10+
end
11+
end
12+
13+
ShiftedData(H::OpH, σ::T) where {T, OpH} = ShiftedData{T, OpH}(H, σ)
14+
15+
# Forward product: y = α(H + σI)x + βy
16+
function shifted_prod!(y, data::ShiftedData, x, α, β)
17+
# y = α * H * x + β * y
18+
mul!(y, data.H, x, α, β)
19+
20+
# y = y + (α * σ) * x
21+
if !(iszero(data.σ) || iszero(α))
22+
axpy!* data.σ, x, y)
23+
end
24+
return y
25+
end
26+
27+
# Transpose product: y = α(Hᵀ + σI)x + βy
28+
function shifted_tprod!(y, data::ShiftedData, x, α, β)
29+
# y = α * Hᵀ * x + β * y
30+
mul!(y, transpose(data.H), x, α, β)
31+
32+
# y = y + (α * σ) * x
33+
if !(iszero(data.σ) || iszero(α))
34+
axpy!* data.σ, x, y)
35+
end
36+
return y
37+
end
38+
39+
# Conjugate transpose (adjoint) product: y = α(Hᴴ + conj(σ)I)x + βy
40+
function shifted_ctprod!(y, data::ShiftedData, x, α, β)
41+
# y = α * Hᴴ * x + β * y
42+
mul!(y, adjoint(data.H), x, α, β)
43+
44+
# y = y + (α * conj(σ)) * x
45+
if !(iszero(data.σ) || iszero(α))
46+
axpy!* conj(data.σ), x, y)
47+
end
48+
return y
49+
end
50+
51+
"""
52+
ShiftedOperator(H, σ=0)
53+
54+
Construct a linear operator representing `op = H + σI`.
55+
"""
56+
mutable struct ShiftedOperator{T, OpH, F, Ft, Fct} <: AbstractLinearOperator{T}
57+
nrow::Int
58+
ncol::Int
59+
symmetric::Bool
60+
hermitian::Bool
61+
prod!::F # Closure for op * x
62+
tprod!::Ft # Closure for transpose(op) * x
63+
ctprod!::Fct # Closure for adjoint(op) * x
64+
data::ShiftedData{T, OpH}
65+
nprod::Int # Internal counter for products
66+
ntprod::Int # Internal counter for transpose products
67+
nctprod::Int # Internal counter for adjoint products
68+
end
69+
70+
function ShiftedOperator(H::OpH, σ_in::Number = zero(eltype(H))) where {OpH}
71+
T = eltype(H)
72+
σ = convert(T, σ_in) # Enforces that σ matches the element type of H
73+
74+
data = ShiftedData(H, σ)
75+
76+
prod! = (y, x, α, β) -> shifted_prod!(y, data, x, α, β)
77+
tprod! = (y, x, α, β) -> shifted_tprod!(y, data, x, α, β)
78+
ctprod! = (y, x, α, β) -> shifted_ctprod!(y, data, x, α, β)
79+
80+
n = size(H, 1)
81+
82+
is_sym = issymmetric(H)
83+
is_herm = ishermitian(H) && isreal(σ)
84+
85+
return ShiftedOperator(n, n, is_sym, is_herm, prod!, tprod!, ctprod!, data, 0, 0, 0)
86+
end
87+
88+
size(op::ShiftedOperator) = (op.nrow, op.ncol)
89+
issymmetric(op::ShiftedOperator) = op.symmetric
90+
ishermitian(op::ShiftedOperator) = op.hermitian && isreal(op.data.σ)
91+
92+
has_args5(op::ShiftedOperator) = true
93+
use_prod5!(op::ShiftedOperator) = true
94+
95+
isallocated5(op::ShiftedOperator) = true
96+
97+
storage_type(op::ShiftedOperator{T}) where {T} = storage_type(op.data.H)
98+
99+
function reset!(op::ShiftedOperator)
100+
op.nprod = 0
101+
op.ntprod = 0
102+
op.nctprod = 0
103+
return op
104+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("test_normest.jl")
1616
include("test_diag.jl")
1717
include("test_chainrules.jl")
1818
include("test_solve_shifted_system.jl")
19+
include("test_shifted_operator.jl")
1920
include("gpu/test_S_kwarg.jl")
2021
include("gpu/jlarrays.jl")
2122
if Sys.isapple() && occursin("arm64", Sys.MACHINE)

test/test_shifted_operator.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
@testset "ShiftedOperator Tests" begin
2+
@testset "Real Symmetric (Float64)" begin
3+
n = 5
4+
H_dense = rand(n, n)
5+
H_dense = H_dense + H_dense'
6+
7+
H_op = LinearOperator(H_dense; symmetric = true, hermitian = true)
8+
9+
σ = 2.0
10+
op = ShiftedOperator(H_op, σ)
11+
12+
A_ref = H_dense + σ * I
13+
14+
x = rand(n)
15+
y = zeros(n)
16+
17+
@test size(op) == (n, n)
18+
@test issymmetric(op) == true
19+
@test ishermitian(op) == true
20+
@test op.data.σ == σ
21+
22+
mul!(y, op, x)
23+
@test y A_ref * x
24+
25+
y_orig = rand(n)
26+
y = copy(y_orig)
27+
α, β = 0.5, -1.0
28+
mul!(y, op, x, α, β)
29+
@test y α * (A_ref * x) + β * y_orig
30+
31+
op_t = transpose(op)
32+
33+
y_t = zeros(n)
34+
mul!(y_t, op_t, x)
35+
@test y_t transpose(A_ref) * x
36+
end
37+
38+
@testset "Complex Non-Hermitian" begin
39+
n = 5
40+
H_dense = rand(ComplexF64, n, n)
41+
H_op = LinearOperator(H_dense)
42+
43+
σ = 1.0 + 2.0im
44+
op = ShiftedOperator(H_op, σ)
45+
46+
@test !issymmetric(op)
47+
@test !ishermitian(op)
48+
49+
A_ref = H_dense + σ * I
50+
x = rand(ComplexF64, n)
51+
52+
op_c = adjoint(op)
53+
54+
y_c = zeros(ComplexF64, n)
55+
mul!(y_c, op_c, x)
56+
@test y_c adjoint(A_ref) * x
57+
end
58+
59+
@testset "Mutation (Updating Sigma)" begin
60+
n = 3
61+
H_dense = rand(n, n)
62+
op = ShiftedOperator(LinearOperator(H_dense), 1.0)
63+
64+
x = ones(n)
65+
y1 = op * x
66+
67+
op.data.σ = 10.0
68+
69+
y2 = op * x
70+
71+
@test !(y1 y2)
72+
@test y2 (H_dense + 10.0*I) * x
73+
end
74+
75+
@testset "Mutation (Dynamic Hermitian Check)" begin
76+
n = 3
77+
H_dense = rand(ComplexF64, n, n)
78+
H_dense = H_dense + H_dense'
79+
H_op = LinearOperator(H_dense; symmetric=false, hermitian=true)
80+
81+
σ = 2.0
82+
op = ShiftedOperator(H_op, σ)
83+
@test ishermitian(op)
84+
85+
op.data.σ = 2.0 + 1.0im
86+
@test !ishermitian(op)
87+
88+
op.data.σ = 3.0
89+
@test ishermitian(op)
90+
end
91+
92+
@testset "Strict Type Constraint" begin
93+
H = LinearOperator(rand(Float32, 5, 5))
94+
σ = 1.0
95+
op = ShiftedOperator(H, σ)
96+
97+
@test eltype(op) == Float32
98+
@test op.data.σ isa Float32
99+
100+
x = rand(Float32, 5)
101+
y = op * x
102+
@test eltype(y) == Float32
103+
end
104+
105+
@testset "Coverage & Utilities" begin
106+
n = 5
107+
H = LinearOperator(rand(n, n))
108+
x = rand(n)
109+
y = zeros(n)
110+
op = ShiftedOperator(H, 2.0)
111+
112+
mul!(y, transpose(op), x, 0.5, 1.0)
113+
114+
mul!(y, transpose(op), x, 0.0, 1.0)
115+
116+
op_zero = ShiftedOperator(H, 0.0)
117+
mul!(y, transpose(op_zero), x)
118+
119+
@test LinearOperators.isallocated5(op) == true
120+
121+
@test LinearOperators.storage_type(op) == LinearOperators.storage_type(H)
122+
123+
op.nprod = 10
124+
reset!(op)
125+
@test op.nprod == 0
126+
end
127+
end

0 commit comments

Comments
 (0)