Skip to content

Commit 4b45cbc

Browse files
committed
Support type promotion in ShiftedOperator
Updated ShiftedOperator to promote types between operator and shift value, ensuring consistent output types. Added tests to verify type promotion and correct behavior when mixing Float32 and Float64 inputs. Also added a reset! function to ShiftedOperator for resetting operation counters.
1 parent 7750f0d commit 4b45cbc

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

src/shifted_operators.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,11 @@ function ShiftedOperator(
9494
)
9595
end
9696

97-
function ShiftedOperator(H::OpH, σ::T = zero(eltype(H))) where {T, OpH}
97+
function ShiftedOperator(H::OpH, σ_in::Number = zero(eltype(H))) where {OpH}
98+
T = promote_type(eltype(H), typeof(σ_in))
9899

99-
data = ShiftedData(H, σ)
100+
data = ShiftedData(H, σ_in)
101+
σ = convert(T, σ_in)
100102

101103
prod! = (y, x, α, β) -> shifted_prod!(y, data, x, α, β)
102104
tprod! = (y, x, α, β) -> shifted_tprod!(y, data, x, α, β)
@@ -130,3 +132,10 @@ function adjoint(op::ShiftedOperator)
130132
# (H + σI)ᴴ = Hᴴ + conj(σ)I
131133
return ShiftedOperator(adjoint(op.data.H), conj(op.data.σ))
132134
end
135+
136+
function reset!(op::ShiftedOperator)
137+
op.nprod = 0
138+
op.ntprod = 0
139+
op.nctprod = 0
140+
return op
141+
end

test/test_shifted_operator.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,22 @@
8585
@test !(y1 y2)
8686
@test y2 (H_dense + 10.0*I) * x
8787
end
88+
@testset "Type Promotion" begin
89+
# H is Float32
90+
H = LinearOperator(rand(Float32, 5, 5))
91+
92+
# σ is Float64
93+
σ = 1.0 # default Float64
94+
95+
op = ShiftedOperator(H, σ)
96+
97+
# The operator itself should be Float64 (promoted)
98+
@test eltype(op) == Float64
99+
@test op.data.σ isa Float64
100+
101+
# Ensure operations return the promoted type
102+
x = rand(Float32, 5)
103+
y = op * x
104+
@test eltype(y) == Float64
105+
end
88106
end

0 commit comments

Comments
 (0)