Skip to content

Commit d2bf084

Browse files
committed
remove bad eltype methods
1 parent 0cbcdc3 commit d2bf084

7 files changed

Lines changed: 22 additions & 20 deletions

File tree

src/helpers/algebra/companion_matrix.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ struct CompanionMatrix{R <: Real, T <: AbstractVector{R}} <: AbstractMatrix{R}
2020
θ::T
2121
end
2222

23-
Base.eltype(::CompanionMatrix{R}) where {R} = R
24-
Base.size(cmatrix::CompanionMatrix) = (length(cmatrix.θ), length(cmatrix.θ))
25-
Base.length(cmatrix::CompanionMatrix) = prod(size(cmatrix))
23+
Base.eltype(::Type{<:CompanionMatrix{R}}) where {R} = R
24+
Base.size(cmatrix::CompanionMatrix) = (length(cmatrix.θ), length(cmatrix.θ))
25+
Base.length(cmatrix::CompanionMatrix) = prod(size(cmatrix))
2626

2727
Base.getindex(cmatrix::CompanionMatrix, i::Int) = getindex(cmatrix, map(r -> r + 1, reverse(divrem(i - 1, first(size(cmatrix)))))...)
2828

@@ -40,9 +40,9 @@ struct CompanionMatrixTransposed{R <: Real, T <: AbstractVector{R}} <: AbstractM
4040
θ::T
4141
end
4242

43-
Base.eltype(::CompanionMatrixTransposed{R}) where {R} = R
44-
Base.size(cmatrix::CompanionMatrixTransposed) = (length(cmatrix.θ), length(cmatrix.θ))
45-
Base.length(cmatrix::CompanionMatrixTransposed) = prod(size(cmatrix))
43+
Base.eltype(::Type{<:CompanionMatrixTransposed{R}}) where {R} = R
44+
Base.size(cmatrix::CompanionMatrixTransposed) = (length(cmatrix.θ), length(cmatrix.θ))
45+
Base.length(cmatrix::CompanionMatrixTransposed) = prod(size(cmatrix))
4646

4747
Base.getindex(cmatrix::CompanionMatrixTransposed, i::Int) = getindex(cmatrix, map(r -> r + 1, reverse(divrem(i - 1, first(size(cmatrix)))))...)
4848

src/helpers/algebra/permutation_matrix.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ function PermutationMatrix(dim::T; switch_first::Bool = true) where {T <: Intege
4040
end
4141

4242
# extensions of base functionality
43-
Base.eltype(::PermutationMatrix{T}) where {T} = T
43+
Base.eltype(::Type{<:PermutationMatrix{T}}) where {T} = T
44+
4445
function Base.size(mat::PermutationMatrix)
4546
nr_elements = length(mat.ind)
4647
return (nr_elements, nr_elements)

src/helpers/algebra/standard_basis_vector.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ function StandardBasisVector(length::Int, index::Int, scale::T = one(Int)) where
2828
return StandardBasisVector{T}(length, index, scale)
2929
end
3030

31-
Base.eltype(::StandardBasisVector{T}) where {T} = T
3231
Base.eltype(::Type{StandardBasisVector{T}}) where {T} = T
3332

3433
# extensions of base functionality

src/marginal.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ MacroHelpers.@proxy_methods Marginal getdata [
133133
Base.precision,
134134
Base.length,
135135
Base.ndims,
136-
Base.size,
137-
Base.eltype
136+
Base.size
138137
]
139138

139+
# Eltype is special here, because it should be only defined on types
140+
# Otherwise it causes invalidations and slower compile times
141+
Base.eltype(::Type{<:Marginal{D}}) where {D} = Base.eltype(D)
142+
140143
Distributions.mean(fn::Function, marginal::Marginal) = mean(fn, getdata(marginal))
141144

142145
"""

src/message.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,8 @@ function multiply_messages(prod_strategy, left::Message, right::Message)
136136
return Message(new_dist, is_prod_clamped, is_prod_initial, new_addons)
137137
end
138138

139-
constrain_form_as_message(message::Message, form_constraint) = Message(
140-
constrain_form(form_constraint, getdata(message)), is_clamped(message), is_initial(message), getaddons(message)
141-
)
139+
constrain_form_as_message(message::Message, form_constraint) =
140+
Message(constrain_form(form_constraint, getdata(message)), is_clamped(message), is_initial(message), getaddons(message))
142141

143142
# Note: we need extra Base.Generator(as_message, messages) step here, because some of the messages might be VMP messages
144143
# We want to cast it explicitly to a Message structure (which as_message does in case of DeferredMessage)
@@ -187,10 +186,13 @@ MacroHelpers.@proxy_methods Message getdata [
187186
Base.precision,
188187
Base.length,
189188
Base.ndims,
190-
Base.size,
191-
Base.eltype
189+
Base.size
192190
]
193191

192+
# Eltype is special here, because it should be only defined on types
193+
# Otherwise it causes invalidations and slower compile times
194+
Base.eltype(::Type{<:Message{D}}) where {D} = Base.eltype(D)
195+
194196
Distributions.mean(fn::Function, message::Message) = mean(fn, getdata(message))
195197

196198
## Deferred Message

src/nodes/predefined/autoregressive.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ Base.size(precision::ARPrecisionMatrix) = (precision.order, precision.order)
128128
Base.getindex(precision::ARPrecisionMatrix, i::Int, j::Int) = (i === 1 && j === 1) ? precision.γ : ((i === j) ? convert(eltype(precision), huge) : zero(eltype(precision)))
129129

130130
Base.eltype(::Type{<:ARPrecisionMatrix{T}}) where {T} = T
131-
Base.eltype(::ARPrecisionMatrix{T}) where {T} = T
132131

133132
Base.convert(::Type{AbstractArray{T}}, matrix::ARPrecisionMatrix{R}) where {T, R} = ARPrecisionMatrix(matrix.order, convert(T, matrix.γ))
134133
Base.convert(::Type{AbstractArray{T}}, matrix::ARPrecisionMatrix{T}) where {T} = matrix
@@ -165,7 +164,6 @@ Base.size(transition::ARTransitionMatrix) = (transition.order, transition.order)
165164
Base.getindex(transition::ARTransitionMatrix, i::Int, j::Int) = (i === 1 && j === 1) ? transition.inv_γ : zero(eltype(transition))
166165

167166
Base.eltype(::Type{<:ARTransitionMatrix{T}}) where {T} = T
168-
Base.eltype(::ARTransitionMatrix{T}) where {T} = T
169167

170168
Base.convert(::Type{AbstractArray{T}}, matrix::ARTransitionMatrix{R}) where {T, R} = ARTransitionMatrix{T}(matrix.order, convert(T, matrix.inv_γ))
171169
Base.convert(::Type{AbstractArray{T}}, matrix::ARTransitionMatrix{T}) where {T} = matrix

src/nodes/predefined/gcv.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ ExponentialLinearQuadratic(approximation, a::Real, b::Real, c::Real, d::Real)
1414
ExponentialLinearQuadratic(approximation, a::Integer, b::Integer, c::Integer, d::Integer) = ExponentialLinearQuadratic(approximation, float(a), float(b), float(c), float(d))
1515

1616
Base.eltype(::Type{<:ExponentialLinearQuadratic{A, T}}) where {A, T} = T
17-
Base.eltype(::ExponentialLinearQuadratic{A, T}) where {A, T} = T
1817

1918
Base.precision(dist::ExponentialLinearQuadratic) = mean_invcov(dist)[2]
2019

@@ -79,8 +78,8 @@ default_meta(::Type{GCV}) = DefaultGCVNodeMetadata
7978
@average_energy GCV (q_y_x::MultivariateNormalDistributionsFamily, q_z::NormalDistributionsFamily, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) = begin
8079
y_x_mean, y_x_cov = mean_cov(q_y_x)
8180
z_mean, z_var = mean_var(q_z)
82-
κ_mean, κ_var = mean_var(q_κ)
83-
ω_mean, ω_var = mean_var(q_ω)
81+
κ_mean, κ_var = mean_var(q_κ)
82+
ω_mean, ω_var = mean_var(q_ω)
8483

8584
ksi = (κ_mean^2) * z_var + κ_var * ((z_mean^2) + z_var)
8685
psi = @inbounds (y_x_mean[2] - y_x_mean[1])^2 + y_x_cov[1, 1] + y_x_cov[2, 2] - y_x_cov[1, 2] - y_x_cov[2, 1]

0 commit comments

Comments
 (0)