From ab8a0f8765516908d1cab986af697f81f4c10eb6 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 12 Feb 2021 16:23:04 -0500 Subject: [PATCH 1/8] draft of serialization --- src/Gen.jl | 3 + src/dynamic/dynamic.jl | 1 + src/dynamic/serialization.jl | 47 +++++++++++ src/dynamic/trace.jl | 3 + src/modeling_library/call_at/call_at.jl | 7 ++ src/modeling_library/choice_at/choice_at.jl | 8 ++ src/modeling_library/custom_determ.jl | 7 ++ src/modeling_library/map/map.jl | 2 + src/modeling_library/recurse/recurse.jl | 21 +++++ src/modeling_library/switch/switch.jl | 12 +++ src/modeling_library/unfold/unfold.jl | 2 + src/modeling_library/vector.jl | 17 ++++ src/serialization.jl | 94 +++++++++++++++++++++ src/static_ir/static_ir.jl | 7 +- src/static_ir/trace.jl | 46 +++++++++- src/trie.jl | 25 +++++- 16 files changed, 295 insertions(+), 7 deletions(-) create mode 100644 src/dynamic/serialization.jl create mode 100644 src/serialization.jl diff --git a/src/Gen.jl b/src/Gen.jl index a3a40e38b..dc7b43363 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -49,6 +49,9 @@ include("trie.jl") # generative function interface include("gen_fn_interface.jl") +# serialization/deserialization for traces +include("serialization.jl") + # built-in data types for arg-diff and ret-diff values include("diff.jl") diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index d83055444..2411213bf 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -175,6 +175,7 @@ function gen_fn_changed_error(addr) error("Generative function changed at address: $addr") end +include("serialization.jl") include("simulate.jl") include("generate.jl") include("propose.jl") diff --git a/src/dynamic/serialization.jl b/src/dynamic/serialization.jl new file mode 100644 index 000000000..88c03203c --- /dev/null +++ b/src/dynamic/serialization.jl @@ -0,0 +1,47 @@ +function _record_to_serializable(r::ChoiceOrCallRecord{T}) where {T <: Trace} + @assert !r.is_choice + return 
ChoiceOrCallRecord(to_serializable_trace(r.subtrace_or_retval), r.score, r.noise, r.is_choice) +end +function _record_to_serializable(r::ChoiceOrCallRecord) + @assert r.is_choice + return r +end +function _record_from_serializable(r::ChoiceOrCallRecord{T}, gf::GenerativeFunction) where {T <: SerializableTrace} + @assert !r.is_choice + return ChoiceOrCallRecord(from_serializable_trace(r.subtrace_or_retval, gf), r.score, r.noise, r.is_choice) +end +function _record_from_serializable(r::ChoiceOrCallRecord, dist::Distribution) + @assert r.is_choice + return r +end +function _trie_to_serializable(trie::Trie) + triemap(trie, identity, _record_to_serializable) +end +function to_serializable_trace(tr::DynamicDSLTrace) + return GenericST( + _trie_to_serializable(tr.trie), + (tr.isempty, tr.score, tr.noise, tr.args, tr.retval) + ) +end + +# since a Dynamic Gen Function doesn't store +# what sub-generative-function is at which address, +# we have to run the generative function to get access to this! +mutable struct GFDeserializeState + trace::DynamicDSLTrace + serialized::GenericST +end +function from_serializable_trace(st::GenericST, gen_fn::DynamicDSLFunction{T}) where T + trace = DynamicDSLTrace{T}(gen_fn, Trie{Any, ChoiceOrCallRecord}(), st.properties...) + state = GFDeserializeState(trace, st) + exec(gen_fn, state, trace.args) + return trace +end +function traceat(state::GFDeserializeState, dist_or_gen_fn, args, key) + record = _record_from_serializable(state.serialized.subtraces[key], dist_or_gen_fn) + state.trace.trie[key] = record + return record.is_choice ? 
record.subtrace_or_retval : get_retval(record.subtrace_or_retval) +end +function splice(state::GFDeserializeState, gf::DynamicDSLFunction, args::Tuple) + return exec(gf, state, args) +end \ No newline at end of file diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..a72e475d9 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -43,6 +43,9 @@ mutable struct DynamicDSLTrace{T} <: Trace # retval is not known yet new(gen_fn, trie, true, 0, 0, args) end + function DynamicDSLTrace{T}(gen_fn::T, trie, isempty, score, noise, args, retval) where {T} + new(gen_fn, trie, isempty, score, noise, args, retval) + end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 88d59f3f5..c2917c230 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -157,4 +157,11 @@ function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) (kernel_input_grads..., nothing) end +function to_serializable_trace(tr::CallAtTrace) + return GenericST(to_serializable_trace(tr.subtrace), tr.key) +end +function from_serializable_trace(st::GenericST, gf::CallAtCombinator) + return get_trace_type(gf)(gf, from_serializable_trace(st.subtraces, gf.kernel), st.properties) +end + export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 09cb922fa..ebf8278a9 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -172,4 +172,12 @@ function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) (kernel_arg_grads[2:end]..., nothing) end + +function to_serializable_trace(tr::ChoiceAtTrace) + return GenericST(nothing, (tr.value, tr.key, tr.kernel_args, tr.score)) +end +function from_serializable_trace(st::GenericST, gf::ChoiceAtCombinator) + return get_trace_type(gf)(gf, 
st.properties...) +end + export choice_at diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 24d6d90f2..39021abdc 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -204,4 +204,11 @@ has_argument_grads(gen_fn::CustomUpdateGF) = tuple(fill(nothing, num_args(gen_fn apply_with_state(gen_fn::CustomUpdateGF, args) = error("not implemented") +function to_serializable_trace(tr::CustomDetermGFTrace) + return GenericST(nothing, (tr.retval, tr.state, tr.args)) +end +function from_serializable_trace(st::GenericST, gf::CustomDetermGF) + return get_trace_type(gf)(gf, st.properties...) +end + export CustomUpdateGF, num_args diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index 1bb695ff7..b94c9f852 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -43,6 +43,8 @@ function get_prev_and_new_lengths(args::Tuple, prev_trace) (new_length, prev_length) end +_addr_to_gen_fn(gf::Map, _) = gf.kernel + include("assess.jl") include("propose.jl") include("simulate.jl") diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index f89f9b987..4034d680e 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -193,6 +193,27 @@ function get_aggregation_constraints(constraints::ChoiceMap, cur::Int) get_submap(constraints, (cur, Val(:aggregation))) end +function to_serializable_trace(tr::RecurseTrace) + return GenericST( + ( + Dict(k => to_serializable_trace(subtr) for (k, tr) in tr.production_traces), + Dict(k => to_serializable_trace(subtr) for (k, tr) in tr.aggregation_traces) + ), + tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices + ) +end +function from_serializable_trace(st::GenericST, gf::Recurse) + production_traces = PersistentHashMap{TODO}( + k => from_serializable_trace(subst, gf.production_kern) + for (k, subst) in st.subtraces[1] + ) + 
aggregation_traces = PersistentHashMap{TODO}( + k => from_serializable_trace(subst, gf.aggregation_kern) + for (k, subst) in st.subtraces[2] + ) + return get_trace_type(gf)(gf, production_traces, aggregation_traces, st.properties...) +end + ############ # simulate # ############ diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 821143448..d0eb1dffe 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -29,6 +29,18 @@ function (gen_fn::Switch{C})(index::C, args...) where C retval end +function to_serializable_trace(tr::SwitchTrace) + GenericST(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise)) +end +function from_serializable_trace(c::GenericST, gf::Switch) + (index, retval, args, score, noise) = c.properties + GenericST( + gf, index, + from_serializable_trace(c.subtraces, gf.branches[index]), + retval, args, score, noise + ) +end + include("assess.jl") include("propose.jl") include("simulate.jl") diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 44238e3b7..7b4bae703 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -60,6 +60,8 @@ function check_length(len::Int) end end +_addr_to_gen_fn(gf::Unfold, _) = gf.kernel + include("simulate.jl") include("generate.jl") include("propose.jl") diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index f35360b50..0bf804468 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -184,3 +184,20 @@ function vector_remove_deleted_applications(subtraces, retval, prev_length, new_ end (subtraces, retval) end + +################# +# Serialization # +################# +function to_serializable_trace(trace::VectorTrace) + GenericST( + [to_serializable_trace(st) for st in trace.subtraces], + (trace.retval, trace.args, trace.len, trace.num_nonempty) + ) +end +function 
from_serializable_trace(st::GenericST, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U} + PersistentVector{U}( + from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i)) + for (i, serialized_subtrace) in st.subtraces + ) + get_trace_type(gf)(gf, st.properties...) +end diff --git a/src/serialization.jl b/src/serialization.jl new file mode 100644 index 000000000..80102486d --- /dev/null +++ b/src/serialization.jl @@ -0,0 +1,94 @@ +using Serialization: serialize, deserialize + +""" + SerializableTrace + +A representation of a `Trace` which can be serialized. Obtainable via `to_serializable_trace`. +This does not need to contain the `GenerativeFunction` which produced the trace; +to deserialize (using `from_serializable_trace`), the `GenerativeFunction` must be provided. +""" +abstract type SerializableTrace end + +""" + to_serializable_trace(trace::Trace) + +Get a SerializableTrace representing the `trace` in a serializable manner. +""" +function to_serializable_trace(trace::Trace) + return GenericST(trace) +end + +""" + from_serializable_trace(st::SerializableTrace, fn::GenerativeFunction) + +Get the trace of the given generative function encoded by the serializable trace object. +""" +function from_serializable_trace(::SerializableTrace, ::GenerativeFunction) + error("Not implemented.") +end + +""" + DefaultST <: SerializableTrace + +A serializable trace which serializes by attempting to call `Base.serialize` +on the original trace object. + +Many trace types cannot be reliably serialized using this. +""" +struct DefaultST{T} <: SerializableTrace + trace::T + GenericST(trace::T) where {T <: Trace} = new{T}(trace) +end +from_serializable_trace(st::DefaultST, ::GenerativeFunction) = st.trace + +# """ +# ChoiceMapST <: SerializableTrace + +# A serializable trace which encodes a choicemap, +# and uses `Gen.generate` with the encoded choicemap to deserialize. + +# This may not save untraced randomness in a trace. 
+# """ +# struct ChoiceMapST{A, C} <: SerializableTrace +# args::A +# cm::C +# ChoiceMapST(args::Tuple, cm::ChoiceMap) = new(args, cm) +# end +# function from_serializable_trace(st::ChoiceMapST, gf::GenerativeFunction) +# trace, _ = generate(gf, st.args, st.cm) +# return trace +# end + +""" + serialize_trace(stream::IO, trace::Trace) + serialize_trace(filename::AbstractString, trace::Trace) + +Serialize the given trace to the given stream or file, by converting to a `SerializableTrace`. +""" +function serialize_trace(filename_or_io::Union{IO, AbstractString}, trace::Trace) + return serialize(filename_or_io, to_serializable_trace(trace)) +end + +""" + deserialize_trace(stream::IO, gen_fn::GenerativeFunction) + deserialize_trace(filename::AbstractString, gen_fn::GenerativeFunction) + +Deserialize the trace for the given generative function stored in the given stream or file +(as saved via `serialize_trace`). +""" +function deserialize_trace(filename_or_io::Union{IO, AbstractString}, gf::GenerativeFunction) + return from_serializable_trace(deserialize(filename_or_io), gf) +end + +""" + GenericST <: SerializableTrace + +A SerializableTrace which contains some subtraces which have been recursively converted +to `SerializableTrace`s, and some properties which are directly serializable. 
+""" +struct GenericST{S, P} <: SerializableTrace + subtraces::S + properties::P +end + +export to_serializable_trace, from_serializable_trace, serialize_trace, deserialize_trace \ No newline at end of file diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index 5b156d0aa..3f492a36c 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -38,7 +38,7 @@ end function generate_generative_function(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions) - (trace_defns, trace_struct_name) = generate_trace_type_and_methods(ir, name, options) + (trace_defns, trace_struct_name, tracefields) = generate_trace_type_and_methods(ir, name, options) gen_fn_type_name = gensym("StaticGenFunction_$name") return_type = ir.return_node.typ @@ -60,7 +60,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) end - Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + + serialization_code = generate_serialization_methods(ir, trace_struct_name, gen_fn_type_name, tracefields) + + Expr(:block, trace_defns, gen_fn_defn, serialization_code, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end include("render_ir.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..9cc1464e0 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -86,7 +86,9 @@ const return_value_fieldname = gensym("retval") struct TraceField fieldname::Symbol typ::Union{Symbol,Expr,QuoteNode} + holds_subtrace::Bool end +TraceField(f, t) = TraceField(f, t, false) function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptions) fields = TraceField[] @@ -103,7 +105,7 @@ function get_trace_fields(ir::StaticIR, 
options::StaticIRGenerativeFunctionOptio for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) - push!(fields, TraceField(subtrace_fieldname, subtrace_type)) + push!(fields, TraceField(subtrace_fieldname, subtrace_type, true)) end if options.cache_julia_nodes for node in ir.julia_nodes @@ -124,8 +126,44 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options: mutable = false fields = get_trace_fields(ir, options) field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields) - Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), + return ( + fields, + Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) + ) +end + +function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, gen_fn_typename::Symbol, fields) + to_subtraces_exprs = [:(tr.$(field.fieldname)) for field in fields if field.holds_subtrace] + to_properties_exprs = [:(tr.$(field.fieldname)) for field in fields if !field.holds_subtrace] + + # fields will have a bunch of properties, then the subtraces, then more properties + + num_initial_props = 0 + for field in fields + if !field.holds_subtrace + num_initial_props += 1 + else + break; + end + end + gen_fns = [QuoteNode(node.generative_function) for node in ir.call_nodes] + + quote + function $(GlobalRef(Gen, :to_serializable_trace))(tr::$trace_struct_name) + return $(GlobalRef(Gen, :GenericST))( + $(Expr(:tuple, to_subtraces_exprs...)), + $(Expr(:tuple, to_properties_exprs...)) + ) + end + function $(GlobalRef(Gen, :from_serializable_trace))(st::$(GlobalRef(Gen, :GenericST)), gf::$gen_fn_typename) + return $trace_struct_name( + st.properties[1:$num_initial_props]..., + ($(GlobalRef(Gen, :from_serializable_trace))(args...) 
for args in zip(st.subtraces, $gen_fns))..., + st.properties[$(num_initial_props + 1):end]... + ) + end + end end function generate_isempty(trace_struct_name::Symbol) @@ -284,7 +322,7 @@ end function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions) trace_struct_name = gensym("StaticIRTrace_$name") - trace_struct_expr = generate_trace_struct(ir, trace_struct_name, options) + (fields, trace_struct_expr) = generate_trace_struct(ir, trace_struct_name, options) isempty_expr = generate_isempty(trace_struct_name) get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) @@ -303,7 +341,7 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_choices_expr, get_schema_expr, get_values_shallow_expr, get_submaps_shallow_expr, static_get_value_exprs..., static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) - (exprs, trace_struct_name) + (exprs, trace_struct_name, fields) end export StaticIRTrace diff --git a/src/trie.jl b/src/trie.jl index 0d1c2a8a2..e623d167b 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -2,7 +2,7 @@ # Trie # ################## -struct Trie{K,V} <: ChoiceMap +struct Trie{K,V} leaf_nodes::Dict{K,V} internal_nodes::Dict{K,Trie{K,V}} end @@ -32,6 +32,7 @@ Base.isempty(trie::Trie) = isempty(trie.leaf_nodes) && isempty(trie.internal_nod get_leaf_nodes(trie::Trie) = trie.leaf_nodes get_internal_nodes(trie::Trie) = trie.internal_nodes +Base.:(==)(t1::Trie, t2::Trie) = get_leaf_nodes(t1) == get_leaf_nodes(t2) && get_internal_nodes(t1) == get_internal_nodes(t2) function Base.values(trie::Trie) iterators = convert(Vector{Any}, collect(map(values, values(trie.internal_nodes)))) push!(iterators, values(trie.leaf_nodes)) @@ -179,6 +180,28 @@ Base.haskey(trie::Trie, key) = has_leaf_node(trie, key) Base.getindex(trie::Trie, key) = get_leaf_node(trie, key) +""" + triemap(trie::Trie, key_converter, leaf_converter) + 
+Get a new trie by applying the function `key_converter` to every key in the trie +and applying the function `leaf_converter` to every leaf node in the trie. +""" +function triemap(trie::Trie{K, V}, key_converter, leaf_converter) where {K, V} + new_keytype = Core.Compiler.return_type(key_converter, Tuple{K}) + KT = Union{Base.return_types(key_converter, (K,))...} + LT = Union{Base.return_types(leaf_converter, (V,))...} + converted_leafs = Dict{KT, LT}( + key_converter(k) => leaf_converter(v) for (k, v) in trie.leaf_nodes + ) + converted_internals = Dict{KT, Trie{KT, LT}}( + key_converter(k) => convert_to_serializable_trie(subtrie, key_converter, leaf_converter, new_keytype, new_leaftype) + for (k, subtrie) in trie.internal_nodes + ) + + return Trie{KT, LT}(converted_leafs, converted_internals) +end + + export Trie export set_internal_node! export delete_internal_node! From 9a25f23225b854137758109cce35e5f1a99adaf9 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 12 Feb 2021 21:24:28 -0500 Subject: [PATCH 2/8] progress debugging --- src/dynamic/serialization.jl | 3 ++- src/modeling_library/choice_at/choice_at.jl | 2 +- src/modeling_library/custom_determ.jl | 2 +- src/modeling_library/map/map.jl | 2 +- src/modeling_library/recurse/recurse.jl | 24 ++++++++++----------- src/modeling_library/unfold/unfold.jl | 2 +- src/modeling_library/vector.jl | 12 +++++------ src/serialization.jl | 4 ++-- src/static_ir/static_ir.jl | 2 +- test/dsl/dynamic_dsl.jl | 22 +++++++++++++++++++ test/dsl/static_dsl.jl | 5 +++++ test/modeling_library/call_at.jl | 4 ++++ test/modeling_library/choice_at.jl | 4 ++++ test/modeling_library/custom_determ.jl | 3 +++ test/modeling_library/map.jl | 5 +++++ test/modeling_library/recurse.jl | 2 ++ test/modeling_library/switch.jl | 2 ++ test/modeling_library/unfold.jl | 5 +++++ test/runtests.jl | 20 ++++++++++++++++- 19 files changed, 98 insertions(+), 27 deletions(-) diff --git a/src/dynamic/serialization.jl b/src/dynamic/serialization.jl index 
88c03203c..1d5a9c102 100644 --- a/src/dynamic/serialization.jl +++ b/src/dynamic/serialization.jl @@ -44,4 +44,5 @@ function traceat(state::GFDeserializeState, dist_or_gen_fn, args, key) end function splice(state::GFDeserializeState, gf::DynamicDSLFunction, args::Tuple) return exec(gf, state, args) -end \ No newline at end of file +end +read_param(::GFDeserializeState, ::Symbol) = nothing \ No newline at end of file diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index ebf8278a9..f37b64bfe 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,6 +25,7 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +has_value(choices::ChoiceAtChoiceMap, addr) = addr == choices.key function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end @@ -172,7 +173,6 @@ function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) (kernel_arg_grads[2:end]..., nothing) end - function to_serializable_trace(tr::ChoiceAtTrace) return GenericST(nothing, (tr.value, tr.key, tr.kernel_args, tr.score)) end diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 39021abdc..74a2dedaf 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -208,7 +208,7 @@ function to_serializable_trace(tr::CustomDetermGFTrace) return GenericST(nothing, (tr.retval, tr.state, tr.args)) end function from_serializable_trace(st::GenericST, gf::CustomDetermGF) - return get_trace_type(gf)(gf, st.properties...) 
+ return get_trace_type(gf)(st.properties..., gf) end export CustomUpdateGF, num_args diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index b94c9f852..5aaa63ab6 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -43,7 +43,7 @@ function get_prev_and_new_lengths(args::Tuple, prev_trace) (new_length, prev_length) end -_addr_to_gen_fn(gf::Map, _) = gf.kernel +_gen_fn_at_addr(gf::Map, _) = gf.kernel include("assess.jl") include("propose.jl") diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 4034d680e..23bbad0d7 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -196,21 +196,21 @@ end function to_serializable_trace(tr::RecurseTrace) return GenericST( ( - Dict(k => to_serializable_trace(subtr) for (k, tr) in tr.production_traces), - Dict(k => to_serializable_trace(subtr) for (k, tr) in tr.aggregation_traces) + Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.production_traces), + Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.aggregation_traces) ), - tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices + (tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices) ) end -function from_serializable_trace(st::GenericST, gf::Recurse) - production_traces = PersistentHashMap{TODO}( - k => from_serializable_trace(subst, gf.production_kern) - for (k, subst) in st.subtraces[1] - ) - aggregation_traces = PersistentHashMap{TODO}( - k => from_serializable_trace(subst, gf.aggregation_kern) - for (k, subst) in st.subtraces[2] - ) +function from_serializable_trace(st::GenericST, gf::Recurse{S, T}) where {S, T} + production_traces = PersistentHashMap{Int, S}() + for (k, subst) in st.subtraces[1] + production_traces = assoc(production_traces, k, from_serializable_trace(subst, gf.production_kern)) + end + aggregation_traces = PersistentHashMap{Int, T}() + for (k, subst) in st.subtraces[2] + 
aggregation_traces = assoc(aggregation_traces, k, from_serializable_trace(subst, gf.aggregation_kern)) + end return get_trace_type(gf)(gf, production_traces, aggregation_traces, st.properties...) end diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 7b4bae703..ae077a5eb 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -60,7 +60,7 @@ function check_length(len::Int) end end -_addr_to_gen_fn(gf::Unfold, _) = gf.kernel +_gen_fn_at_addr(gf::Unfold, _) = gf.kernel include("simulate.jl") include("generate.jl") diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 0bf804468..bf4db68ce 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -191,13 +191,13 @@ end function to_serializable_trace(trace::VectorTrace) GenericST( [to_serializable_trace(st) for st in trace.subtraces], - (trace.retval, trace.args, trace.len, trace.num_nonempty) + (trace.retval, trace.args, trace.len, trace.num_nonempty, trace.score, trace.noise) ) end function from_serializable_trace(st::GenericST, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U} - PersistentVector{U}( - from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i)) - for (i, serialized_subtrace) in st.subtraces + subtraces = PersistentVector{U}( + [from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i)) + for (i, serialized_subtrace) in enumerate(st.subtraces)] ) - get_trace_type(gf)(gf, st.properties...) -end + get_trace_type(gf)(gf, subtraces, st.properties...) +end \ No newline at end of file diff --git a/src/serialization.jl b/src/serialization.jl index 80102486d..1c8b4869d 100644 --- a/src/serialization.jl +++ b/src/serialization.jl @@ -15,7 +15,7 @@ abstract type SerializableTrace end Get a SerializableTrace representing the `trace` in a serializable manner. 
""" function to_serializable_trace(trace::Trace) - return GenericST(trace) + return DefaultST(trace) end """ @@ -37,7 +37,7 @@ Many trace types cannot be reliably serialized using this. """ struct DefaultST{T} <: SerializableTrace trace::T - GenericST(trace::T) where {T <: Trace} = new{T}(trace) + DefaultST(trace::T) where {T <: Trace} = new{T}(trace) end from_serializable_trace(st::DefaultST, ::GenerativeFunction) = st.trace diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index 3f492a36c..e85bf576a 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -63,7 +63,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati serialization_code = generate_serialization_methods(ir, trace_struct_name, gen_fn_type_name, tracefields) - Expr(:block, trace_defns, gen_fn_defn, serialization_code, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + Expr(:block, trace_defns, gen_fn_defn #=, serialization_code=#, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end include("render_ir.jl") diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 25e3a50f4..a4672b31a 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -534,4 +534,26 @@ end end +@testset "serialization" begin + @gen function bar() + @trace(normal(0, 1), :a) + end + + @gen function baz() + @trace(normal(0, 1), :b) + end + + @gen function foo() + if @trace(bernoulli(0.4), :branch) + @trace(normal(0, 1), :x) + @trace(bar(), :u) + else + @trace(normal(0, 1), :y) + @trace(baz(), :v) + end + end + tr = simulate(foo, ()) + @test serialize_loop_successful(tr) +end + end diff --git a/test/dsl/static_dsl.jl b/test/dsl/static_dsl.jl index e062d06f9..85aa3c689 100644 --- a/test/dsl/static_dsl.jl +++ b/test/dsl/static_dsl.jl @@ -603,4 +603,9 @@ ch = get_choices(tr) @test length(get_submaps_shallow(ch)) == 1 end +@testset "serialization" begin + tr = simulate(model, ([1., 2., 
3., 4.],)) + @test serialize_loop_successful(tr) +end + end # @testset "static DSL" diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..a79f12e3b 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -55,6 +55,10 @@ (trace, y) end + @testset "serialization" begin + @test serialize_loop_successful(get_trace()[1]) + end + @testset "project" begin (trace, y) = get_trace() @test isapprox(project(trace, EmptySelection()), 0.) diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..d241d5195 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -49,6 +49,10 @@ trace end + @testset "serialization" begin + @test serialize_loop_successful(get_trace()) + end + @testset "project" begin trace = get_trace() @test isapprox(project(trace, EmptySelection()), 0.) diff --git a/test/modeling_library/custom_determ.jl b/test/modeling_library/custom_determ.jl index a0d473c14..8c36ab0f7 100644 --- a/test/modeling_library/custom_determ.jl +++ b/test/modeling_library/custom_determ.jl @@ -84,6 +84,9 @@ @test w == 0. 
@test get_retval(trace) == 1 + 2 + 3 + # serialization + @test serialize_loop_successful(trace) + # update (UnknownChange) trace = simulate(MyDeterministicGF(), ([1, 2, 3],)) new_trace, w, retdiff = update(trace, ([1, 2, 4],), (UnknownChange(),), EmptyChoiceMap()) diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3c1f820fe..32b590070 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -38,6 +38,11 @@ @test isapprox(weight, logpdf(normal, z1, 4., 1.)) end + @testset "serialization" begin + (trace, _) = generate(bar, (xs[1:2], ys[1:2])) + @test serialize_loop_successful(trace) + end + @testset "propose" begin (choices, weight) = propose(bar, (xs[1:2], ys[1:2])) z1 = choices[1 => :z] diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 7fe7a9592..fdfbb63d2 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -177,6 +177,8 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false + @test serialize_loop_successful(trace) + # update non-structure choice new_constraints = choicemap() new_constraints[(3, Val(:aggregation)) => :prefix] = false diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 8c183aa16..2b558e762 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -13,6 +13,8 @@ @test swtr[:z] == tr[:z] @test project(swtr, AllSelection()) == project(swtr.branch, AllSelection()) @test project(swtr, EmptySelection()) == swtr.noise + + @test serialize_loop_successful(tr) end # ------------ Bare combinator ------------ # diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index a9a203816..f97700fc7 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -15,6 +15,11 @@ @test length(foo(5, 0., 1.0, 1.0)) == 5 end + @testset "serialization" begin + tr = simulate(foo, (3, 0.1, 
0.2, 0.3)) + @test serialize_loop_successful(tr) + end + @testset "simulate" begin x_init = 0.1 alpha = 0.2 diff --git a/test/runtests.jl b/test/runtests.jl index 6fddf20b4..0ff95ab0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,6 +76,24 @@ end const dx = 1e-6 +""" +Attempts to serialize then deserialize the given trace, and returns +whether the pre-serialization and post-serialization traces are equal. +""" +function serialize_loop_successful(tr) + io = IOBuffer() + serialize_trace(io, tr) + seek(io, 0) + des_tr = deserialize_trace(io, get_gen_fn(tr)) + + if get_choices(des_tr) != get_choices(tr) + display(tr) + display(des_tr) + end + + return get_choices(des_tr) == get_choices(tr) +end + include("autodiff.jl") include("diff.jl") include("selection.jl") @@ -86,4 +104,4 @@ include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file From 70d1c2d6c905e9dbd583dfcebe5cd2d5b17f8e98 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:02:42 -0400 Subject: [PATCH 3/8] bug fix --- src/modeling_library/switch/switch.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index ad7ed67ec..03ad31e30 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -29,6 +29,8 @@ function (gen_fn::Switch{C})(index::C, args...) 
where C retval end +include("trace.jl") + function to_serializable_trace(tr::SwitchTrace) GenericST(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise)) end @@ -41,7 +43,6 @@ function from_serializable_trace(c::GenericST, gf::Switch) ) end -include("trace.jl") include("assess.jl") include("propose.jl") include("simulate.jl") From 33291de4eb5083e9a32a9b092486fd2ee7d55ace Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:02:53 -0400 Subject: [PATCH 4/8] static serialization --- src/static_ir/static_ir.jl | 4 ++-- src/static_ir/trace.jl | 26 +++++++++++++++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index b67776c3a..1eda20e8b 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -37,10 +37,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol; track_diffs=fa end function generate_generative_function(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions) + gen_fn_type_name = gensym("StaticGenFunction_$name") (trace_defns, trace_struct_name, tracefields) = generate_trace_type_and_methods(ir, name, options) - gen_fn_type_name = gensym("StaticGenFunction_$name") return_type = ir.return_node.typ trace_type = trace_struct_name has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...) 
@@ -64,7 +64,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati serialization_code = generate_serialization_methods(ir, trace_struct_name, gen_fn_type_name, tracefields) - Expr(:block, trace_defns, gen_fn_defn #=, serialization_code=#, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + Expr(:block, trace_defns, gen_fn_defn, serialization_code, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end include("print_ir.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 9cc1464e0..2469aa9f2 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -134,11 +134,13 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options: end function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, gen_fn_typename::Symbol, fields) - to_subtraces_exprs = [:(tr.$(field.fieldname)) for field in fields if field.holds_subtrace] + to_subtraces_exprs = [ + :($(GlobalRef(Gen, :to_serializable_trace))(tr.$(field.fieldname))) + for field in fields if field.holds_subtrace + ] to_properties_exprs = [:(tr.$(field.fieldname)) for field in fields if !field.holds_subtrace] # fields will have a bunch of properties, then the subtraces, then more properties - num_initial_props = 0 for field in fields if !field.holds_subtrace @@ -147,7 +149,8 @@ function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, break; end end - gen_fns = [QuoteNode(node.generative_function) for node in ir.call_nodes] + + gen_fns = [node.generative_function for node in ir.call_nodes] quote function $(GlobalRef(Gen, :to_serializable_trace))(tr::$trace_struct_name) @@ -156,11 +159,18 @@ function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, $(Expr(:tuple, to_properties_exprs...)) ) end - function $(GlobalRef(Gen, :from_serializable_trace))(st::$(GlobalRef(Gen, :GenericST)), gf::$gen_fn_typename) + function 
$(GlobalRef(Gen, :from_serializable_trace))( + st::$(GlobalRef(Gen, :GenericST)), + gf::$gen_fn_typename + ) return $trace_struct_name( st.properties[1:$num_initial_props]..., - ($(GlobalRef(Gen, :from_serializable_trace))(args...) for args in zip(st.subtraces, $gen_fns))..., - st.properties[$(num_initial_props + 1):end]... + ( + $(GlobalRef(Gen, :from_serializable_trace))(args...) + for args in zip(st.subtraces, $gen_fns) + )..., + st.properties[$(num_initial_props + 1):end]..., + gf ) end end @@ -340,7 +350,9 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_args_expr, get_retval_expr, get_choices_expr, get_schema_expr, get_values_shallow_expr, get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + static_has_value_exprs..., static_get_submap_exprs..., + getindex_exprs... + ) (exprs, trace_struct_name, fields) end From 3648ceb3e670d3538bd62c1a919bb30c39e3e35c Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:03:27 -0400 Subject: [PATCH 5/8] better test --- test/dsl/static_dsl.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/dsl/static_dsl.jl b/test/dsl/static_dsl.jl index 85aa3c689..91bf2b1e5 100644 --- a/test/dsl/static_dsl.jl +++ b/test/dsl/static_dsl.jl @@ -605,7 +605,12 @@ end @testset "serialization" begin tr = simulate(model, ([1., 2., 3., 4.],)) - @test serialize_loop_successful(tr) + @test Gen.to_serializable_trace(tr) isa Gen.GenericST + io = IOBuffer() + serialize_trace(io, tr) + seek(io, 0) + deserialized_tr = deserialize_trace(io, model) + @test get_choices(deserialized_tr) == get_choices(tr) end end # @testset "static DSL" From e6b99f6b4828c57e50c3c0e97cbf8136792fbf57 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:04:45 -0400 Subject: [PATCH 6/8] GenericST -> GenericSerializableTrace --- src/dynamic/serialization.jl | 6 ++-- 
src/modeling_library/call_at/call_at.jl | 4 +-- src/modeling_library/choice_at/choice_at.jl | 4 +-- src/modeling_library/custom_determ.jl | 4 +-- src/modeling_library/recurse/recurse.jl | 4 +-- src/modeling_library/switch/switch.jl | 6 ++-- src/modeling_library/vector.jl | 4 +-- src/serialization.jl | 38 ++------------------- src/static_ir/trace.jl | 4 +-- test/dsl/static_dsl.jl | 2 +- 10 files changed, 22 insertions(+), 54 deletions(-) diff --git a/src/dynamic/serialization.jl b/src/dynamic/serialization.jl index 1d5a9c102..3d57b9432 100644 --- a/src/dynamic/serialization.jl +++ b/src/dynamic/serialization.jl @@ -18,7 +18,7 @@ function _trie_to_serializable(trie::Trie) triemap(trie, identity, _record_to_serializable) end function to_serializable_trace(tr::DynamicDSLTrace) - return GenericST( + return GenericSerializableTrace( _trie_to_serializable(tr.trie), (tr.isempty, tr.score, tr.noise, tr.args, tr.retval) ) @@ -29,9 +29,9 @@ end # we have to run the generative function to get access to this! mutable struct GFDeserializeState trace::DynamicDSLTrace - serialized::GenericST + serialized::GenericSerializableTrace end -function from_serializable_trace(st::GenericST, gen_fn::DynamicDSLFunction{T}) where T +function from_serializable_trace(st::GenericSerializableTrace, gen_fn::DynamicDSLFunction{T}) where T trace = DynamicDSLTrace{T}(gen_fn, Trie{Any, ChoiceOrCallRecord}(), st.properties...) 
state = GFDeserializeState(trace, st) exec(gen_fn, state, trace.args) diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 2ed9c5362..177aa645e 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -158,9 +158,9 @@ function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) end function to_serializable_trace(tr::CallAtTrace) - return GenericST(to_serializable_trace(tr.subtrace), tr.key) + return GenericSerializableTrace(to_serializable_trace(tr.subtrace), tr.key) end -function from_serializable_trace(st::GenericST, gf::CallAtCombinator) +function from_serializable_trace(st::GenericSerializableTrace, gf::CallAtCombinator) return get_trace_type(gf)(gf, from_serializable_trace(st.subtraces, gf.kernel), st.properties) end diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index f37b64bfe..b89afa73c 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -174,9 +174,9 @@ function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) end function to_serializable_trace(tr::ChoiceAtTrace) - return GenericST(nothing, (tr.value, tr.key, tr.kernel_args, tr.score)) + return GenericSerializableTrace(nothing, (tr.value, tr.key, tr.kernel_args, tr.score)) end -function from_serializable_trace(st::GenericST, gf::ChoiceAtCombinator) +function from_serializable_trace(st::GenericSerializableTrace, gf::ChoiceAtCombinator) return get_trace_type(gf)(gf, st.properties...) 
end diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 74a2dedaf..fdb13bc3a 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -205,9 +205,9 @@ has_argument_grads(gen_fn::CustomUpdateGF) = tuple(fill(nothing, num_args(gen_fn apply_with_state(gen_fn::CustomUpdateGF, args) = error("not implemented") function to_serializable_trace(tr::CustomDetermGFTrace) - return GenericST(nothing, (tr.retval, tr.state, tr.args)) + return GenericSerializableTrace(nothing, (tr.retval, tr.state, tr.args)) end -function from_serializable_trace(st::GenericST, gf::CustomDetermGF) +function from_serializable_trace(st::GenericSerializableTrace, gf::CustomDetermGF) return get_trace_type(gf)(st.properties..., gf) end diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 6ffdb1cb2..73d1326fa 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -194,7 +194,7 @@ function get_aggregation_constraints(constraints::ChoiceMap, cur::Int) end function to_serializable_trace(tr::RecurseTrace) - return GenericST( + return GenericSerializableTrace( ( Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.production_traces), Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.aggregation_traces) @@ -202,7 +202,7 @@ function to_serializable_trace(tr::RecurseTrace) (tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices) ) end -function from_serializable_trace(st::GenericST, gf::Recurse{S, T}) where {S, T} +function from_serializable_trace(st::GenericSerializableTrace, gf::Recurse{S, T}) where {S, T} production_traces = PersistentHashMap{Int, S}() for (k, subst) in st.subtraces[1] production_traces = assoc(production_traces, k, from_serializable_trace(subst, gf.production_kern)) diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 03ad31e30..05323f589 100644 --- 
a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -32,11 +32,11 @@ end include("trace.jl") function to_serializable_trace(tr::SwitchTrace) - GenericST(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise)) + GenericSerializableTrace(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise)) end -function from_serializable_trace(c::GenericST, gf::Switch) +function from_serializable_trace(c::GenericSerializableTrace, gf::Switch) (index, retval, args, score, noise) = c.properties - GenericST( + SwitchTrace( gf, index, from_serializable_trace(c.subtraces, gf.branches[index]), retval, args, score, noise diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index bf4db68ce..2b34f2bbd 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -189,12 +189,12 @@ end # Serialization # ################# function to_serializable_trace(trace::VectorTrace) - GenericST( + GenericSerializableTrace( [to_serializable_trace(st) for st in trace.subtraces], (trace.retval, trace.args, trace.len, trace.num_nonempty, trace.score, trace.noise) ) end -function from_serializable_trace(st::GenericST, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U} +function from_serializable_trace(st::GenericSerializableTrace, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U} subtraces = PersistentVector{U}( [from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i)) for (i, serialized_subtrace) in enumerate(st.subtraces)] diff --git a/src/serialization.jl b/src/serialization.jl index 1c8b4869d..a71ccdd3e 100644 --- a/src/serialization.jl +++ b/src/serialization.jl @@ -15,7 +15,7 @@ abstract type SerializableTrace end Get a SerializableTrace representing the `trace` in a serializable manner. 
""" function to_serializable_trace(trace::Trace) - return DefaultST(trace) + error("Not implemented") end """ @@ -27,38 +27,6 @@ function from_serializable_trace(::SerializableTrace, ::GenerativeFunction) error("Not implemented.") end -""" - DefaultST <: SerializableTrace - -A serializable trace which serializes by attempting to call `Base.serialize` -on the original trace object. - -Many trace types cannot be reliably serialized using this. -""" -struct DefaultST{T} <: SerializableTrace - trace::T - DefaultST(trace::T) where {T <: Trace} = new{T}(trace) -end -from_serializable_trace(st::DefaultST, ::GenerativeFunction) = st.trace - -# """ -# ChoiceMapST <: SerializableTrace - -# A serializable trace which encodes a choicemap, -# and uses `Gen.generate` with the encoded choicemap to deserialize. - -# This may not save untraced randomness in a trace. -# """ -# struct ChoiceMapST{A, C} <: SerializableTrace -# args::A -# cm::C -# ChoiceMapST(args::Tuple, cm::ChoiceMap) = new(args, cm) -# end -# function from_serializable_trace(st::ChoiceMapST, gf::GenerativeFunction) -# trace, _ = generate(gf, st.args, st.cm) -# return trace -# end - """ serialize_trace(stream::IO, trace::Trace) serialize_trace(filename::AbstractString, trace::Trace) @@ -81,12 +49,12 @@ function deserialize_trace(filename_or_io::Union{IO, AbstractString}, gf::Genera end """ - GenericST <: SerializableTrace + GenericSerializableTrace <: SerializableTrace A SerializableTrace which contains some subtraces which have been recursively converted to `SerializableTrace`s, and some properties which are directly serializable. 
""" -struct GenericST{S, P} <: SerializableTrace +struct GenericSerializableTrace{S, P} <: SerializableTrace subtraces::S properties::P end diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 2469aa9f2..4d1f80d92 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -154,13 +154,13 @@ function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, quote function $(GlobalRef(Gen, :to_serializable_trace))(tr::$trace_struct_name) - return $(GlobalRef(Gen, :GenericST))( + return $(GlobalRef(Gen, :GenericSerializableTrace))( $(Expr(:tuple, to_subtraces_exprs...)), $(Expr(:tuple, to_properties_exprs...)) ) end function $(GlobalRef(Gen, :from_serializable_trace))( - st::$(GlobalRef(Gen, :GenericST)), + st::$(GlobalRef(Gen, :GenericSerializableTrace)), gf::$gen_fn_typename ) return $trace_struct_name( diff --git a/test/dsl/static_dsl.jl b/test/dsl/static_dsl.jl index 91bf2b1e5..077f4bdbe 100644 --- a/test/dsl/static_dsl.jl +++ b/test/dsl/static_dsl.jl @@ -605,7 +605,7 @@ end @testset "serialization" begin tr = simulate(model, ([1., 2., 3., 4.],)) - @test Gen.to_serializable_trace(tr) isa Gen.GenericST + @test Gen.to_serializable_trace(tr) isa Gen.GenericSerializableTrace io = IOBuffer() serialize_trace(io, tr) seek(io, 0) From ea65983186206d8eafeb831b36a6de0557dda254 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:14:17 -0400 Subject: [PATCH 7/8] documentation --- docs/src/ref/extending.md | 11 +++++++++++ docs/src/ref/gfi.md | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd480..8e28fc27c 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -228,6 +228,17 @@ If your generative function has trainable parameters, then implement: - [`accumulate_param_gradients!`](@ref) +#### Supporting trace serialization +To support trace serialization, a trace type of type `T` for a generative function of type `G` must 
be convertible into a `SerializableTrace` object, and must be recoverable from a `SerializableTrace` object and the generative function. +```@docs +SerializableTrace +to_serializable_trace +from_serializable_trace +``` +A user must implement `to_serializable_trace(::T)` and `from_serializable_trace(::ST, ::G)` for some concrete type `ST <: SerializableTrace`. This may be a custom type, or the user may use the built-in type +```@docs +GenericSerializableTrace +``` ## Custom modeling languages diff --git a/docs/src/ref/gfi.md b/docs/src/ref/gfi.md index dde929722..d3da5ff26 100644 --- a/docs/src/ref/gfi.md +++ b/docs/src/ref/gfi.md @@ -350,6 +350,17 @@ The set of elements (either arguments, random choices, or trainable parameters) If the return value of the function is conditionally dependent on any element in the gradient source set given the arguments and values of all other random choices, for all possible traces of the function, then the generative function requires a *return value gradient* to compute gradients with respect to elements of the gradient source set. This static property of the generative function is reported by [`accepts_output_grad`](@ref). 
+## Serialization +To serialize a trace `tr` for a generative function `gf` +(save the trace to disk), a user may call +```julia +serialize_trace(filename_or_io::Union{IO, AbstractString}, tr) +``` +To recover the trace, a user may call +```julia +deserialized_tr = deserialize_trace(filename_or_io, gf) +``` + ## Generative function interface The complete set of methods in the generative function interface (GFI) is: From 7f0f7ef485216613cf19e0fa7122fec3d7f52bbc Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 13 Jul 2021 17:14:31 -0400 Subject: [PATCH 8/8] Serialization dep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 538b118e9..31ae909cd 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat]