Handle missing arguments more carefully #1361

@penelopeysm

function convert_model_argument(param_eltype, model_argument)
    T = typeof(model_argument)
    # If the argument contains missing data, then we potentially need to deepcopy it. This
    # is because the argument may be e.g. a vector of missings, and evaluating a
    # tilde-statement like x[1] ~ Normal() would set x[1] = some_not_missing_value, thus
    # mutating x. If you then run the model again with the same argument, x[1] would no
    # longer be missing.
    return if hasmissing(T)
        # It is possible that we could skip the deepcopy, if the argument has to be promoted
        # anyway. For example, if we are running with ForwardDiff and the argument is a
        # Vector{Union{Missing, Float64}}, then we will convert it to a
        # Vector{Union{Missing, ForwardDiff.Dual{...}}} anyway, which will avoid mutating
        # the original argument. We can check for this by first converting and then only
        # deepcopying if the converted value aliases the original.
        # Note that indiscriminately deepcopying can not only lead to reduced performance,
        # but sometimes also incorrect behaviour with ReverseDiff.jl, because ReverseDiff
        # expects to be able to track array mutations. See e.g.
        # https://github.com/TuringLang/DynamicPPL.jl/pull/1015#issuecomment-3166011534
        converted_argument = convert(
            promote_model_type_argument(param_eltype, T), model_argument
        )
        if converted_argument === model_argument
            deepcopy(model_argument)
        else
            converted_argument
        end
    else
        model_argument
    end
end
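
The aliasing check in the comments above can be seen in plain Julia, without DynamicPPL. This is only a sketch: BigFloat stands in for ForwardDiff.Dual as "some other element type that forces a conversion", and the variable names are made up for illustration.

```julia
# A vector with a missing entry, like the model argument discussed above.
x = Union{Missing,Float64}[missing, 1.0]

# Converting to the type it already has returns the very same object,
# so writing into the result would also write into x.
same = convert(Vector{Union{Missing,Float64}}, x)
aliases_original = same === x

# Converting to a different element type allocates a new array.
# (BigFloat is an illustrative stand-in for a promoted type such as a dual number.)
promoted = convert(Vector{Union{Missing,BigFloat}}, x)
promoted_is_fresh = promoted !== x

# Mutating the fresh array leaves the original untouched.
promoted[1] = BigFloat(0.5)
original_still_missing = ismissing(x[1])
```

When the conversion already produces a fresh array, no deepcopy is needed; only the aliasing case needs one.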

convert_model_argument works correctly for arrays that contain missing elements, but arrays of arrays and mutable structs silently yield incorrect results.

Notice how s.x is a parameter on the first run but not on the second, which is consistent with the first run having overwritten s.x in place:

julia> using DynamicPPL, Distributions

julia> mutable struct S
           x::Union{Missing,Float64}
       end

julia> @model function f(s::S)
           s.x ~ Normal()
           a ~ Normal()
       end
f (generic function with 2 methods)

julia> model = f(S(missing))
Model{typeof(f), (:s,), (), (), Tuple{S}, Tuple{}, DefaultContext, false}(f, (s = S(missing),), NamedTuple(), DefaultContext())

julia> rand(model)
VarNamedTuple
├─ s => VarNamedTuple
│       └─ x => -0.6584816275651195
└─ a => 1.124233167066466

julia> rand(model)
VarNamedTuple
└─ a => 0.47209376169407236
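
The suspected mechanism can be reproduced without DynamicPPL. The sketch below is not the library's code: type_hasmissing is a hypothetical stand-in for a check that only inspects the argument's type and array element types, deep_hasmissing is a hypothetical value-level check that also looks inside struct fields, and simulate_run! imitates what s.x ~ Normal() appears to do when s.x is missing.

```julia
mutable struct S
    x::Union{Missing,Float64}
end

# Hypothetical type-level check: looks at the type itself and at array eltypes only.
type_hasmissing(::Type{T}) where {T} = Missing <: T
type_hasmissing(::Type{<:AbstractArray{T}}) where {T} = type_hasmissing(T)

# Hypothetical value-level check: recurses into array elements and struct fields.
# It does not guard against cyclic references.
function deep_hasmissing(x)
    x === missing && return true
    x isa AbstractArray && return any(deep_hasmissing, x)
    isstructtype(typeof(x)) || return false
    return any(i -> isdefined(x, i) && deep_hasmissing(getfield(x, i)), 1:nfields(x))
end

# Imitates one model evaluation: if s.x is missing it is treated as a parameter
# and a sampled value is written back into s.
function simulate_run!(s::S)
    was_parameter = ismissing(s.x)
    if was_parameter
        s.x = randn()
    end
    return was_parameter
end

# Run on a copy whenever the value-level check finds a missing.
run_guarded(s) = simulate_run!(deep_hasmissing(s) ? deepcopy(s) : s)

unguarded = S(missing)
first_unguarded = simulate_run!(unguarded)   # s.x is a parameter
second_unguarded = simulate_run!(unguarded)  # s.x was overwritten, no longer a parameter

guarded = S(missing)
first_guarded = run_guarded(guarded)
second_guarded = run_guarded(guarded)        # the original is still missing
```

Here type_hasmissing(S) is false while deep_hasmissing(S(missing)) is true, which is the gap a type-only check would leave. Note that copying unconditionally runs into the ReverseDiff concern raised in the code comments, so this only illustrates the problem and is not a proposed fix.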
