Skip to content

Commit 4d7c252

Browse files
authored
Merge pull request #489 from ReactiveBayes/fix-invalidations
Fix invalidations from SnoopCompile
2 parents c355e67 + 478ae8c commit 4d7c252

11 files changed

Lines changed: 35 additions & 41 deletions

File tree

src/helpers/algebra/companion_matrix.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ struct CompanionMatrix{R <: Real, T <: AbstractVector{R}} <: AbstractMatrix{R}
2020
θ::T
2121
end
2222

23-
Base.eltype(::CompanionMatrix{R}) where {R} = R
24-
Base.size(cmatrix::CompanionMatrix) = (length(cmatrix.θ), length(cmatrix.θ))
25-
Base.length(cmatrix::CompanionMatrix) = prod(size(cmatrix))
23+
Base.size(cmatrix::CompanionMatrix) = (length(cmatrix.θ), length(cmatrix.θ))
24+
Base.length(cmatrix::CompanionMatrix) = prod(size(cmatrix))
2625

2726
Base.getindex(cmatrix::CompanionMatrix, i::Int) = getindex(cmatrix, map(r -> r + 1, reverse(divrem(i - 1, first(size(cmatrix)))))...)
2827

@@ -40,9 +39,8 @@ struct CompanionMatrixTransposed{R <: Real, T <: AbstractVector{R}} <: AbstractM
4039
θ::T
4140
end
4241

43-
Base.eltype(::CompanionMatrixTransposed{R}) where {R} = R
44-
Base.size(cmatrix::CompanionMatrixTransposed) = (length(cmatrix.θ), length(cmatrix.θ))
45-
Base.length(cmatrix::CompanionMatrixTransposed) = prod(size(cmatrix))
42+
Base.size(cmatrix::CompanionMatrixTransposed) = (length(cmatrix.θ), length(cmatrix.θ))
43+
Base.length(cmatrix::CompanionMatrixTransposed) = prod(size(cmatrix))
4644

4745
Base.getindex(cmatrix::CompanionMatrixTransposed, i::Int) = getindex(cmatrix, map(r -> r + 1, reverse(divrem(i - 1, first(size(cmatrix)))))...)
4846

src/helpers/algebra/permutation_matrix.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ function PermutationMatrix(dim::T; switch_first::Bool = true) where {T <: Intege
3939
return PermutationMatrix(ind)
4040
end
4141

42-
# extensions of base functionality
43-
Base.eltype(::PermutationMatrix{T}) where {T} = T
4442
function Base.size(mat::PermutationMatrix)
4543
nr_elements = length(mat.ind)
4644
return (nr_elements, nr_elements)

src/helpers/algebra/standard_basis_vector.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ function StandardBasisVector(length::Int, index::Int, scale::T = one(Int)) where
2828
return StandardBasisVector{T}(length, index, scale)
2929
end
3030

31-
Base.eltype(::StandardBasisVector{T}) where {T} = T
32-
Base.eltype(::Type{StandardBasisVector{T}}) where {T} = T
33-
3431
# extensions of base functionality
3532
Base.size(e::StandardBasisVector) = (length(e),)
3633
Base.size(e::StandardBasisVector, d) = d::Integer == 1 ? length(e) : 1

src/helpers/helpers.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,8 @@ Base.IteratorSize(::Type{<:SkipIndexIterator}) = HasLength()
6464
Base.IteratorEltype(::Type{<:SkipIndexIterator}) = HasEltype()
6565
Base.IndexStyle(::Type{<:SkipIndexIterator}) = IndexLinear()
6666

67-
Base.eltype(::Type{<:SkipIndexIterator{T}}) where {T} = T
68-
Base.length(iter::SkipIndexIterator) = length(iter.iterator) - 1
69-
Base.size(iter::SkipIndexIterator) = (length(iter),)
67+
Base.length(iter::SkipIndexIterator) = length(iter.iterator) - 1
68+
Base.size(iter::SkipIndexIterator) = (length(iter),)
7069

7170
Base.@propagate_inbounds Base.getindex(iter::SkipIndexIterator, i::Int) = i < skip(iter) ? iter.iterator[i] : iter.iterator[i + 1]
7271
Base.@propagate_inbounds Base.getindex(iter::SkipIndexIterator, i::CartesianIndex{1}) = Base.getindex(iter, first(i.I))

src/marginal.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ MacroHelpers.@proxy_methods Marginal getdata [
133133
Base.precision,
134134
Base.length,
135135
Base.ndims,
136-
Base.size,
137-
Base.eltype
136+
Base.size
138137
]
139138

139+
# Eltype is special here, because it should be only defined on types
140+
# Otherwise it causes invalidations and slower compile times
141+
Base.eltype(::Type{<:Marginal{D}}) where {D} = Base.eltype(D)
142+
140143
Distributions.mean(fn::Function, marginal::Marginal) = mean(fn, getdata(marginal))
141144

142145
"""

src/message.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,13 @@ MacroHelpers.@proxy_methods Message getdata [
187187
Base.precision,
188188
Base.length,
189189
Base.ndims,
190-
Base.size,
191-
Base.eltype
190+
Base.size
192191
]
193192

193+
# Eltype is special here, because it should be only defined on types
194+
# Otherwise it causes invalidations and slower compile times
195+
Base.eltype(::Type{<:Message{D}}) where {D} = Base.eltype(D)
196+
194197
Distributions.mean(fn::Function, message::Message) = mean(fn, getdata(message))
195198

196199
## Deferred Message

src/nodes/predefined/autoregressive.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,6 @@ end
127127
Base.size(precision::ARPrecisionMatrix) = (precision.order, precision.order)
128128
Base.getindex(precision::ARPrecisionMatrix, i::Int, j::Int) = (i === 1 && j === 1) ? precision.γ : ((i === j) ? convert(eltype(precision), huge) : zero(eltype(precision)))
129129

130-
Base.eltype(::Type{<:ARPrecisionMatrix{T}}) where {T} = T
131-
Base.eltype(::ARPrecisionMatrix{T}) where {T} = T
132-
133130
Base.convert(::Type{AbstractArray{T}}, matrix::ARPrecisionMatrix{R}) where {T, R} = ARPrecisionMatrix(matrix.order, convert(T, matrix.γ))
134131
Base.convert(::Type{AbstractArray{T}}, matrix::ARPrecisionMatrix{T}) where {T} = matrix
135132

@@ -164,9 +161,6 @@ end
164161
Base.size(transition::ARTransitionMatrix) = (transition.order, transition.order)
165162
Base.getindex(transition::ARTransitionMatrix, i::Int, j::Int) = (i === 1 && j === 1) ? transition.inv_γ : zero(eltype(transition))
166163

167-
Base.eltype(::Type{<:ARTransitionMatrix{T}}) where {T} = T
168-
Base.eltype(::ARTransitionMatrix{T}) where {T} = T
169-
170164
Base.convert(::Type{AbstractArray{T}}, matrix::ARTransitionMatrix{R}) where {T, R} = ARTransitionMatrix{T}(matrix.order, convert(T, matrix.inv_γ))
171165
Base.convert(::Type{AbstractArray{T}}, matrix::ARTransitionMatrix{T}) where {T} = matrix
172166

src/nodes/predefined/gcv.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ ExponentialLinearQuadratic(approximation, a::Real, b::Real, c::Real, d::Real)
1414
ExponentialLinearQuadratic(approximation, a::Integer, b::Integer, c::Integer, d::Integer) = ExponentialLinearQuadratic(approximation, float(a), float(b), float(c), float(d))
1515

1616
Base.eltype(::Type{<:ExponentialLinearQuadratic{A, T}}) where {A, T} = T
17-
Base.eltype(::ExponentialLinearQuadratic{A, T}) where {A, T} = T
1817

1918
Base.precision(dist::ExponentialLinearQuadratic) = mean_invcov(dist)[2]
2019

src/pipeline/logger.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ LoggerPipelineStage() = LoggerPipelineStage(Core.stdout, "Log")
1717
LoggerPipelineStage(output::IO) = LoggerPipelineStage(output, "Log")
1818
LoggerPipelineStage(prefix::String) = LoggerPipelineStage(Core.stdout, prefix)
1919

20-
Base.println(stage::LoggerPipelineStage, something) = Base.println(stage, stage.output, append_prefix(stage, something))
20+
logger_pipeline_stage_println(logger::LoggerPipelineStage, something::Any) = logger_pipeline_stage_println(
21+
logger, logger.output, logger_pipeline_stage_append_prefix(logger, something)
22+
)
2123

22-
Base.println(stage::LoggerPipelineStage, output::Core.CoreSTDOUT, something) = Core.println(output, something)
23-
Base.println(stage::LoggerPipelineStage, output, something) = println(output, something)
24+
logger_pipeline_stage_println(logger::LoggerPipelineStage, output::Core.CoreSTDOUT, something) = Core.println(output, something)
25+
logger_pipeline_stage_println(logger::LoggerPipelineStage, output, something) = println(output, something)
2426

25-
append_prefix(stage::LoggerPipelineStage, something) = string("[", stage.prefix, "]", something)
27+
logger_pipeline_stage_append_prefix(logger::LoggerPipelineStage, something) = lazy"[$(logger.prefix)]$something"
2628

27-
apply_pipeline_stage(stage::LoggerPipelineStage, factornode, tag::Val{T}, stream) where {T} = stream |> tap((v) -> println(stage, "[$(functionalform(factornode))][$(T)]: $v"))
28-
apply_pipeline_stage(stage::LoggerPipelineStage, factornode, tag::Tuple{Val{T}, Int}, stream) where {T} = stream |> tap((v) -> println(stage, "[$(functionalform(factornode))][$(T):$(tag[2])]: $v"))
29+
apply_pipeline_stage(logger::LoggerPipelineStage, factornode, tag::Val{T}, stream) where {T} = stream |> tap((v) -> logger_pipeline_stage_println(logger, lazy"[$(functionalform(factornode))][$(T)]: $v"))
30+
apply_pipeline_stage(logger::LoggerPipelineStage, factornode, tag::Tuple{Val{T}, Int}, stream) where {T} = stream |> tap((v) -> logger_pipeline_stage_println(logger, lazy"[$(functionalform(factornode))][$(T):$(tag[2])]: $v"))

src/rule.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ Base.copy(entry::TestRuleEntryInputSpecification) = TestRuleEntryInputSpecificat
853853
Base.values(entry::TestRuleEntryInputSpecification) = Base.Generator((arg) -> arg.second, entry.arguments)
854854

855855
# Convert the `TestRuleEntryInputSpecification` back into the `Expr` form, e.g `(m_x = ..., q_y = ..., meta = ...)`
856-
function Base.convert(::Type{Expr}, test_entry::TestRuleEntryInputSpecification)
856+
function rule_macro_convert_to_expr(test_entry::TestRuleEntryInputSpecification)
857857
tuple = Expr(:tuple)
858858
tuple.args = map((arg) -> Expr(:(=), arg.first, arg.second), test_entry.arguments)
859859
if !isnothing(test_entry.meta)
@@ -889,8 +889,8 @@ struct TestRuleEntry
889889
end
890890

891891
# Convert the `TestRuleEntry` back into the `Expr` form, e.g `(input = ..., output = ...)`
892-
function Base.convert(::Type{Expr}, test_entry::TestRuleEntry)
893-
return Expr(:tuple, Expr(:(=), :input, convert(Expr, test_entry.input)), Expr(:(=), :output, test_entry.output))
892+
function rule_macro_convert_to_expr(test_entry::TestRuleEntry)
893+
return Expr(:tuple, Expr(:(=), :input, rule_macro_convert_to_expr(test_entry.input)), Expr(:(=), :output, test_entry.output))
894894
end
895895

896896
# This function takes a `test` parameter which is expected to be an expression of single test entry.
@@ -937,8 +937,8 @@ end
937937

938938
function test_rules_generate_testset(test_entry::TestRuleEntry, invoke_test_fn, call_macro_fn, rule_specification, configuration)
939939
# `nothing` here is a `LineNumberNode`, macrocall expects a `line` number, but we do not have it here
940-
actual_inputs = convert(Expr, test_entry.input)
941-
actual_output = Expr(:macrocall, call_macro_fn, nothing, rule_specification, convert(Expr, actual_inputs))
940+
actual_inputs = rule_macro_convert_to_expr(test_entry.input)
941+
actual_output = Expr(:macrocall, call_macro_fn, nothing, rule_specification, actual_inputs)
942942
expected_output = test_entry.output
943943
rule_spec_str = "$rule_specification"
944944
rule_inputs_str = "$actual_inputs"

0 commit comments

Comments
 (0)