diff --git a/docs/make.jl b/docs/make.jl index ff4e5b60b..e959a3bc7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,9 +14,11 @@ makedocs( "Introduction" => "index.md", "Library" => [ "Factor nodes" => "lib/nodes.md", + "Variables" => "lib/variables.md", "Messages" => "lib/message.md", "Marginals" => "lib/marginal.md", "Message update rules" => "lib/rules.md", + "Callbacks" => "lib/callbacks.md", "Helper utils" => "lib/helpers.md", "Algebra utils" => "lib/algebra.md", "Specific factor nodes" => [ diff --git a/docs/src/lib/callbacks.md b/docs/src/lib/callbacks.md new file mode 100644 index 000000000..23e0bcd56 --- /dev/null +++ b/docs/src/lib/callbacks.md @@ -0,0 +1,27 @@ +# Callbacks in the Message Passing Procedure + +ReactiveMP provides a way to "hook" into the message passing procedure and listen to various events +via "callbacks". This can be useful, for example, to debug messages or monitor the order of computations. + +```@docs +ReactiveMP.invoke_callback +ReactiveMP.merge_callbacks +ReactiveMP.MergedCallbacks +``` + +## All defined events + +Here is the list of predefined event types, to which a custom callback handler can react to. + +```@docs +ReactiveMP.BeforeMessageRuleCallback +ReactiveMP.AfterMessageRuleCallback +ReactiveMP.BeforeProductOfTwoMessages +ReactiveMP.AfterProductOfTwoMessages +ReactiveMP.BeforeProductOfMessages +ReactiveMP.AfterProductOfMessages +ReactiveMP.BeforeFormConstraintApplied +ReactiveMP.AfterFormConstraintApplied +ReactiveMP.BeforeMarginalComputation +ReactiveMP.AfterMarginalComputation +``` diff --git a/docs/src/lib/message.md b/docs/src/lib/message.md index 6098c4854..036f03330 100644 --- a/docs/src/lib/message.md +++ b/docs/src/lib/message.md @@ -73,13 +73,15 @@ is_clamped(message), is_initial(message) ### [Product of messages](@id lib-messages-product) In message passing framework, in order to compute a posterior we must compute a normalized product of two messages. -For this purpose the `ReactiveMP.jl` uses the `multiply_messages` function, which internally uses the `prod` function -defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the -`BayesBase.jl` package for more information. +For this purpose the `ReactiveMP.jl` uses the [`ReactiveMP.MessageProductContext`](@ref) structure, together with the [`ReactiveMP.compute_product_of_messages`](@ref) and [`ReactiveMP.compute_product_of_two_messages`](@ref) functions. Both functions accept a [`ReactiveMP.AbstractVariable`](@ref) as the first argument to identify which variable the product is being computed for — this is useful for callbacks (e.g. [`ReactiveMP.BeforeProductOfTwoMessages`](@ref)). The [`ReactiveMP.compute_product_of_two_messages`](@ref) function internally uses the `prod` function +defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the `BayesBase.jl` package for more information. ```@docs -ReactiveMP.multiply_messages -ReactiveMP.messages_prod_fn +ReactiveMP.MessageProductContext +ReactiveMP.compute_product_of_two_messages +ReactiveMP.compute_product_of_messages +ReactiveMP.MessagesProductFromLeftToRight +ReactiveMP.MessagesProductFromRightToLeft ``` @@ -95,4 +97,4 @@ A *message mapping* defines how messages are transformed or mapped during the pr ```@docs ReactiveMP.MessageMapping -``` \ No newline at end of file +``` diff --git a/docs/src/lib/nodes.md b/docs/src/lib/nodes.md index e1c97346c..a726c11ed 100644 --- a/docs/src/lib/nodes.md +++ b/docs/src/lib/nodes.md @@ -113,27 +113,6 @@ ReactiveMP.RequireMarginalFunctionalDependencies ReactiveMP.RequireEverythingFunctionalDependencies ``` -## [Customizing Dependencies with Metadata](@id lib-node-metadata-dependencies) - -The functional dependencies of a node can be customized at runtime using options during node activation. This allows for runtime customization of the functional dependencies, e.g. to test different message passing schemes or implement specialized behavior for specific instances of a node type: - -```julia -# Define custom dependencies based on metadata -function ReactiveMP.collect_functional_dependencies(::Type{MyNode}, options::FactorNodeActivationOptions) - if some_condition(options) # a user can specify dependencies based, for example, on metadata - return CustomDependencies() - end - # Fall back to default dependencies - return ReactiveMP.collect_functional_dependencies(MyNode, getdependecies(options)) -end - -# Use custom dependencies during activation -node = factornode(MyNode, ...) -activate!(node, FactorNodeActivationOptions(:custom_behavior, ...)) -``` - -This feature is particularly useful for testing different message passing schemes or implementing specialized behavior for specific instances of a node type. - ## [Node traits](@id lib-node-traits) Each factor node has to define the [`ReactiveMP.is_predefined_node`](@ref) trait function and to specify a [`ReactiveMP.PredefinedNodeFunctionalForm`](@ref) diff --git a/docs/src/lib/variables.md b/docs/src/lib/variables.md new file mode 100644 index 000000000..49f37bef3 --- /dev/null +++ b/docs/src/lib/variables.md @@ -0,0 +1,36 @@ + +# [Variables](@id lib-variables) + +Variables are fundamental building blocks of a factor graph. Each variable represents either a latent quantity to be inferred, an observed data point, or a fixed constant. All variable types are subtypes of [`ReactiveMP.AbstractVariable`](@ref). + +```@docs +ReactiveMP.AbstractVariable +``` + +## [Random variables](@id lib-variables-random) + +Random variables represent latent (unobserved) quantities in the model. During inference, messages flow through them to update the marginal belief. + +```@docs +ReactiveMP.RandomVariable +ReactiveMP.randomvar +``` + +## [Data variables](@id lib-variables-data) + +Data variables represent observed quantities. Their value is not fixed at creation time and can be updated later via [`update!`](@ref). + +```@docs +ReactiveMP.DataVariable +ReactiveMP.datavar +ReactiveMP.update! +``` + +## [Constant variables](@id lib-variables-constant) + +Constant variables hold a fixed value, wrapped in a `PointMass` distribution. Messages from constant variables are always marked as clamped. + +```@docs +ReactiveMP.ConstVariable +ReactiveMP.constvar +``` diff --git a/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl b/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl index d5e3d9c07..1671be72e 100644 --- a/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl +++ b/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl @@ -52,6 +52,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -62,6 +63,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -75,6 +77,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -85,6 +88,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -98,6 +102,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) let interface = factornode.out msgs_names = Val{(:out,)}() @@ -125,6 +130,7 @@ function deltafn_apply_layout( addons, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap @@ -152,6 +158,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -162,5 +169,6 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index 3dc79f1ca..24897c31a 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -21,6 +21,8 @@ include("helpers/algebra/standard_basis_vector.jl") include("constraints/form.jl") +include("callbacks.jl") +include("variable.jl") include("message.jl") include("marginal.jl") include("addons.jl") @@ -76,7 +78,7 @@ include("approximations/cvi_projection.jl") # Equality node is a special case and needs to be included before random variable implementation include("nodes/equality.jl") -include("variables/variable.jl") +include("variables/generic.jl") include("variables/random.jl") include("variables/constant.jl") include("variables/data.jl") @@ -113,6 +115,22 @@ function __init__() """ println(io, errmsg) end + if exc.f === ReactiveMP.invoke_callback && length(argtypes) >= 2 + event_type = length(argtypes) >= 2 ? argtypes[2] : "unknown" + errmsg = """ + + `ReactiveMP.invoke_callback` was called with a callback handler of type `$(argtypes[1])` for event `$(event_type)`, but no matching method was found. This can happen if: + + 1. You implemented a custom callback handler but forgot to define `invoke_callback` for this specific event type. + Make sure your handler has a method like: + ReactiveMP.invoke_callback(::$(argtypes[1]), ::$(event_type), args...) = ... + + 2. You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma. + In Julia, `(key = value)` is parsed as a plain assignment, not a NamedTuple. + Use `(key = value,)` (with a trailing comma) instead. + """ + println(io, errmsg) + end end end diff --git a/src/approximations/unscented.jl b/src/approximations/unscented.jl index 8837b4815..a2b8be33c 100644 --- a/src/approximations/unscented.jl +++ b/src/approximations/unscented.jl @@ -135,13 +135,13 @@ function statistic_estimation( # Compute `C_tilde` only if `C === true` @inbounds C_tilde = if C reshape( - sum( - wc * (xi - m) * (yi - m_tilde) for - (wc, xi, yi) in zip(weights_c, sigma_points, g_sigma) - ), - 1, - d_out, - ) + sum( + wc * (xi - m) * (yi - m_tilde) for + (wc, xi, yi) in zip(weights_c, sigma_points, g_sigma) + ), + 1, + d_out, + ) else nothing end @@ -224,10 +224,10 @@ function unscented_statistics( # Compute `C_tilde` only if `C === true` @inbounds C_tilde = if C sum( - weights_c[k + 1] * - (sigma_points[k + 1] - m) * - (g_sigma[k + 1] - m_tilde)' for k in 0:(2d) - ) + weights_c[k + 1] * + (sigma_points[k + 1] - m) * + (g_sigma[k + 1] - m_tilde)' for k in 0:(2d) + ) else nothing end @@ -258,10 +258,10 @@ function unscented_statistics( # Compute `C_tilde` only if `C === true` @inbounds C_tilde = if C sum( - weights_c[k + 1] * - (sigma_points[k + 1] - m) * - (g_sigma[k + 1] - m_tilde)' for k in 0:(2d) - ) + weights_c[k + 1] * + (sigma_points[k + 1] - m) * + (g_sigma[k + 1] - m_tilde)' for k in 0:(2d) + ) else nothing end diff --git a/src/callbacks.jl b/src/callbacks.jl new file mode 100644 index 000000000..f52cd1c5a --- /dev/null +++ b/src/callbacks.jl @@ -0,0 +1,297 @@ + +""" + invoke_callback(callbacks, event, args...) + +Custom callbacks handlers should implement `invoke_callback` in order to listen to events +during the reactive message passing procedure. +A typical event has type `Val{...}`, e.g. `Val{:before_message_rule_call}`. +Does nothing if `callbacks` is `nothing`. + +```jldoctest +julia> struct MyCustomCallbackHandler end; + +julia> ReactiveMP.invoke_callback(::MyCustomCallbackHandler, event, args...) = print("Event \$(event) has been called with \$(args)"); +``` + +See also: [`ReactiveMP.merge_callbacks`](@ref) +""" +function invoke_callback(callbacks::Nothing, args...) + return nothing +end + +""" + invoke_callback(callbacks::NamedTuple, event::Val{...}, args...) + +The `callbacks` can also be a `NamedTuple` with fields corresponding to event names, e.g. + +```jldoctest +julia> callbacks = (before_message_rule_call = (args...) -> sum(args),); + +julia> ReactiveMP.invoke_callback(callbacks , Val{:before_message_rule_call}(), 1, 2) +3 + +julia> ReactiveMP.invoke_callback(callbacks , Val{:other_event}(), 1, 2, 3) +``` + +If the `NamedTuple` does not have a field corresponding to the event name, the event will be ignored. +""" +function invoke_callback( + callbacks::NamedTuple{K}, ::Val{E}, args... +) where {K, E} + if E in K + return callbacks[E](args...) + end + return nothing +end + +""" + MergedCallbacks{F, C}(reduce_fn, callbacks) + +The result of the [`ReactiveMP.merge_callbacks`](@ref) procedure. +""" +struct MergedCallbacks{F, C} + reduce_fn::F + callbacks::C +end + +""" + merge_callbacks(callbacks_handlers...; reduce_fn = nothing) + +This function accept an arbitrary amount of callback handlers and merges them together. +Some callback handlers may or may not react on certain type of events. + +```jldoctest +julia> handler1 = (event1 = (args...) -> println("Event 1 from handler 1"), event2 = (args...) -> println("Event 2 from handler 1")); + +julia> handler2 = (event1 = (args...) -> println("Event 1 from handler 2"),); + +julia> merged_handler = ReactiveMP.merge_callbacks(handler1, handler2); + +julia> ReactiveMP.invoke_callback(merged_handler, Val(:event1)); +Event 1 from handler 1 +Event 1 from handler 2 + +julia> ReactiveMP.invoke_callback(merged_handler, Val(:event2)); +Event 2 from handler 1 +``` + +If `reduce_fn` is not `nothing`, the result of all the callbacks will be reduced +with the provided reduce function. + +```jldoctest +julia> callback_handler1 = (event1 = (a, b) -> a + b,); + +julia> callback_handler2 = (event1 = (a, b) -> a * b,); + +julia> merged_handler = ReactiveMP.merge_callbacks(callback_handler1, callback_handler2); + +julia> ReactiveMP.invoke_callback(merged_handler, Val(:event1), 2, 3) +(5, 6) + +julia> merged_handler_with_reduce = ReactiveMP.merge_callbacks(callback_handler1, callback_handler2; reduce_fn = +); + +julia> ReactiveMP.invoke_callback(merged_handler_with_reduce, Val(:event1), 2, 3) +11 +``` + +The `reduce_fn` can also be a `NamedTuple` that sets different reduce functions for +different events. + +```jldoctest +julia> callback_handler1 = (event1 = (a, b) -> a + b, event2 = (a, b) -> a - b); + +julia> callback_handler2 = (event1 = (a, b) -> a * b, event2 = (a, b) -> a / b); + +julia> merged_handler = ReactiveMP.merge_callbacks(callback_handler1, callback_handler2; reduce_fn = ( + event1 = +, + event2 = * + )); + +julia> ReactiveMP.invoke_callback(merged_handler, Val(:event1), 4, 5) +29 + +julia> ReactiveMP.invoke_callback(merged_handler, Val(:event2), 5, 5) +0.0 +``` + +See also: [`ReactiveMP.invoke_callback`](@ref) +""" +function merge_callbacks(callback_handlers...; reduce_fn = nothing) + return MergedCallbacks(reduce_fn, callback_handlers) +end + +""" + invoke_callback(merged::MergedCallbacks, event, args...) + +A specialized version of [`ReactiveMP.invoke_callback`](@ref) for [`ReactiveMP.MergedCallbacks`](@ref). +Calls the provided callbacks in order and uses the provided reduce function to +reduce the collection of results into a single one. +""" +function invoke_callback(merged::MergedCallbacks, event, args...) + result = map(merged.callbacks) do callback + invoke_callback(callback, event, args...) + end + return merged_callback_reduce_result(merged.reduce_fn, event, result) +end + +merged_callback_reduce_result(::Nothing, _, result) = result +merged_callback_reduce_result(reduce_fn::F, _, result) where {F} = reduce( + reduce_fn, result +) +# If `reduce_fn` is a NamedTuple, then we choose a specific function for a specific event from this tuple +merged_callback_reduce_result(reduce_fn::NamedTuple{K}, event::Val{E}, result) where {K, E} = merged_callback_reduce_result( + get(reduce_fn, E, nothing), event, result +) + +# All defined events go here, so its easier to document them all in one place + +""" + BeforeMessageRuleCallback # Val{:before_message_rule_call} + +Alias for `Val{:before_message_rule_call}`. This event is being used to call a callback right +before computing the message and calling the corresponding rule. The callback handler for this event +should accept the following positional arguments: +- `mapping` of type [`ReactiveMP.MessageMapping`](@ref), contains information about the node type, etc +- `messages`, typically of type `Tuple` if present, `nothing` otherwise +- `marginals`, typically of type `Tuple` if present, `nothing` otherwise + +```jldoctest +julia> import ReactiveMP: BeforeMessageRuleCallback + +julia> struct MyCallbackHandler end + +julia> ReactiveMP.invoke_callback(::MyCallbackHandler, ::BeforeMessageRuleCallback, mapping, messages, marginals) = println("Before message called!") +``` + +See also: [`ReactiveMP.invoke_callback`](@ref) +""" +const BeforeMessageRuleCallback = Val{:before_message_rule_call} + +""" + AfterMessageRuleCallback # Val{:after_message_rule_call} + +Alias for `Val{:after_message_rule_call}`. This event is being used to call a callback right +after computing the message and calling the corresponding rule. The callback handler for this event +should accept the following positional arguments: +- `mapping` of type [`ReactiveMP.MessageMapping`](@ref), contains information about the node type, etc +- `messages`, typically of type `Tuple` if present, `nothing` otherwise +- `marginals`, typically of type `Tuple` if present, `nothing` otherwise +- `result`, the result of the rule invocation (of `rulefallback`), can be any type +- `addons`, the result of the addons invocation, if present, can be any type + +```jldoctest +julia> import ReactiveMP: AfterMessageRuleCallback + +julia> struct MyCallbackHandler end + +julia> ReactiveMP.invoke_callback(::MyCallbackHandler, ::AfterMessageRuleCallback, mapping, messages, marginals, result, addons) = println("After message called!") +``` +""" +const AfterMessageRuleCallback = Val{:after_message_rule_call} + +""" + BeforeProductOfTwoMessages # Val{:before_product_of_two_messages} + +Alias for `Val{:before_product_of_two_messages}`. This event is being used to call a callback right +before computing the product of two messages. The callback handler for this event +should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `left` of type [`ReactiveMP.Message`](@ref), the left-hand side message in the product +- `right` of type [`ReactiveMP.Message`](@ref), the right-hand side message in the product +""" +const BeforeProductOfTwoMessages = Val{:before_product_of_two_messages} + +""" + AfterProductOfTwoMessages # Val{:after_product_of_two_messages} + +Alias for `Val{:after_product_of_two_messages}`. This event is being used to call a callback right +after computing the product of two messages. The callback handler for this event +should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `left` of type [`ReactiveMP.Message`](@ref), the left-hand side message in the product +- `right` of type [`ReactiveMP.Message`](@ref), the right-hand side message in the product +- `result` of type [`ReactiveMP.Message`](@ref), the resulting message from the product +- `addons`, the computed addons for the result (can be `nothing`) +""" +const AfterProductOfTwoMessages = Val{:after_product_of_two_messages} + +""" + BeforeProductOfMessages # Val{:before_product_of_messages} + +Alias for `Val{:before_product_of_messages}`. This event is being used to call a callback right +before computing the product of a collection of messages (i.e. at the beginning of [`ReactiveMP.compute_product_of_messages`](@ref)). +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`, the collection of messages to be multiplied +""" +const BeforeProductOfMessages = Val{:before_product_of_messages} + +""" + AfterProductOfMessages # Val{:after_product_of_messages} + +Alias for `Val{:after_product_of_messages}`. This event is being used to call a callback right +after computing the product of a collection of messages (i.e. at the end of [`ReactiveMP.compute_product_of_messages`](@ref)). +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`, the original collection of messages that were multiplied +- `result` of type [`ReactiveMP.Message`](@ref), the final result after folding and form constraint application +""" +const AfterProductOfMessages = Val{:after_product_of_messages} + +""" + BeforeFormConstraintApplied # Val{:before_form_constraint_applied} + +Alias for `Val{:before_form_constraint_applied}`. This event is being used to call a callback right +before applying the form constraint via [`ReactiveMP.constrain_form`](@ref). Fires in both +[`ReactiveMP.FormConstraintCheckEach`](@ref) and [`ReactiveMP.FormConstraintCheckLast`](@ref) strategies. +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `strategy`, the form constraint check strategy being used (e.g. [`ReactiveMP.FormConstraintCheckEach`](@ref) or [`ReactiveMP.FormConstraintCheckLast`](@ref)) +- `distribution`, the distribution about to be constrained +""" +const BeforeFormConstraintApplied = Val{:before_form_constraint_applied} + +""" + AfterFormConstraintApplied # Val{:after_form_constraint_applied} + +Alias for `Val{:after_form_constraint_applied}`. This event is being used to call a callback right +after applying the form constraint via [`ReactiveMP.constrain_form`](@ref). Fires in both +[`ReactiveMP.FormConstraintCheckEach`](@ref) and [`ReactiveMP.FormConstraintCheckLast`](@ref) strategies. +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.AbstractVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `strategy`, the form constraint check strategy being used (e.g. [`ReactiveMP.FormConstraintCheckEach`](@ref) or [`ReactiveMP.FormConstraintCheckLast`](@ref)) +- `distribution`, the distribution before the constraint was applied +- `result`, the distribution after the constraint was applied +""" +const AfterFormConstraintApplied = Val{:after_form_constraint_applied} + +""" + BeforeMarginalComputation # Val{:before_marginal_computation} + +Alias for `Val{:before_marginal_computation}`. This event fires right before computing the marginal +for a [`ReactiveMP.RandomVariable`](@ref) from its incoming messages. +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.RandomVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`, the collection of incoming messages used to compute the marginal +""" +const BeforeMarginalComputation = Val{:before_marginal_computation} + +""" + AfterMarginalComputation # Val{:after_marginal_computation} + +Alias for `Val{:after_marginal_computation}`. This event fires right after computing the marginal +for a [`ReactiveMP.RandomVariable`](@ref) from its incoming messages. +The callback handler for this event should accept the following positional arguments: +- `variable` of type [`ReactiveMP.RandomVariable`](@ref) +- `context` of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`, the collection of incoming messages used to compute the marginal +- `result`, the computed marginal +""" +const AfterMarginalComputation = Val{:after_marginal_computation} diff --git a/src/constraints/form.jl b/src/constraints/form.jl index df8996442..968d81bb1 100644 --- a/src/constraints/form.jl +++ b/src/constraints/form.jl @@ -9,72 +9,84 @@ using TupleTools import BayesBase: resolve_prod_strategy import Base: + -# Form constraints are preserved during execution of the `prod` function -# There are two major strategies to check current functional form -# We may check and preserve functional form of the result of the `prod` function -# after each subsequent `prod` -# or we may want to wait after all `prod` functions in the equality chain have been executed +# Form constraints control the functional form of messages during the product computation. +# There are two strategies for when to apply the constraint: +# - `FormConstraintCheckEach`: apply after each pairwise `prod` in `compute_product_of_two_messages` +# - `FormConstraintCheckLast`: apply once at the end in `compute_product_of_messages` """ AbstractFormConstraint -Every functional form constraint is a subtype of `AbstractFormConstraint` abstract type. +Abstract supertype for all form constraints. Subtype this to create custom form constraints +that can be used with [`constrain_form`](@ref) and [`ReactiveMP.MessageProductContext`](@ref). -Note: this is not strictly necessary, but it makes automatic dispatch easier and compatible with the `CompositeFormConstraint`. +Not strictly required (any object works via [`ReactiveMP.WrappedFormConstraint`](@ref)), +but makes dispatch easier and is needed for [`ReactiveMP.CompositeFormConstraint`](@ref) composition via `+`. """ abstract type AbstractFormConstraint end """ FormConstraintCheckEach -This form constraint check strategy checks functional form of the messages product after each product in an equality chain. -Usually if a variable has been connected to multiple nodes we want to perform multiple `prod` to obtain a posterior marginal. -With this form check strategy `constrain_form` function will be executed after each subsequent `prod` function. +Form constraint check strategy that applies [`constrain_form`](@ref) after **each** pairwise product +inside [`ReactiveMP.compute_product_of_two_messages`](@ref). Use this when intermediate results +need to stay in a specific functional form (e.g. to prevent numerical issues during long product chains). + +See also: [`FormConstraintCheckLast`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ struct FormConstraintCheckEach end """ - FormConstraintCheckEach + FormConstraintCheckLast -This form constraint check strategy checks functional form of the last messages product in the equality chain. -Usually if a variable has been connected to multiple nodes we want to perform multiple `prod` to obtain a posterior marginal. -With this form check strategy `constrain_form` function will be executed only once after all subsequenct `prod` functions have been executed. +Form constraint check strategy that applies [`constrain_form`](@ref) **once** at the very end +of [`ReactiveMP.compute_product_of_messages`](@ref), after all pairwise products have been folded. +This is the default strategy and is more efficient when intermediate form doesn't matter. + +See also: [`FormConstraintCheckEach`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ struct FormConstraintCheckLast end """ FormConstraintCheckPickDefault -This form constraint check strategy simply fallbacks to a default check strategy for a given form constraint. +A meta-strategy that defers to the default check strategy of the given form constraint, +as defined by [`default_form_check_strategy`](@ref). """ struct FormConstraintCheckPickDefault end """ default_form_check_strategy(form_constraint) -Returns a default check strategy (e.g. `FormConstraintCheckEach` or `FormConstraintCheckEach`) for a given form constraint object. +Returns the default check strategy (either [`FormConstraintCheckEach`](@ref) or [`FormConstraintCheckLast`](@ref)) +for a given form constraint. Override this for custom constraints to control when they are applied. """ function default_form_check_strategy end """ default_prod_constraint(form_constraint) -Returns a default prod constraint needed to apply a given `form_constraint`. For most form constraints this function returns `ProdGeneric`. +Returns the default product strategy needed to apply a given `form_constraint`. +For most form constraints this returns `BayesBase.GenericProd()`. """ function default_prod_constraint end """ - constrain_form(constraint, something) + constrain_form(constraint, distribution) + +Applies the form `constraint` to `distribution` and returns the constrained result. +This is the main extension point for custom form constraints — implement a method of this function +for your constraint type and the distribution types you want to support. -This function applies a given form constraint to a given object. +See also: [`AbstractFormConstraint`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ function constrain_form end """ UnspecifiedFormConstraint -One of the form constraint objects. Does not imply any form constraints and simply returns the same object as receives. -However it does not allow `DistProduct` to be a valid functional form in the inference backend. +The default form constraint — does nothing and returns the distribution as-is. +Used when no form constraint has been specified in the [`ReactiveMP.MessageProductContext`](@ref). """ struct UnspecifiedFormConstraint <: AbstractFormConstraint end @@ -87,9 +99,10 @@ constrain_form(::UnspecifiedFormConstraint, something) = something """ WrappedFormConstraint(constraint, context) -This is a wrapper for a form constraint object. It allows to pass additional context to the `constrain_form` function. -By default all objects that are not sub-typed from `AbstractFormConstraint` are wrapped into this object. -Use `ReactiveMP.prepare_context` to provide an extra context for a given form constraint, that can be reused between multiple `constrain_form` calls. +A wrapper that pairs a form constraint with an optional precomputed context. +Any object that is not a subtype of [`AbstractFormConstraint`](@ref) gets automatically wrapped +into this during [`ReactiveMP.preprocess_form_constraints`](@ref). +Use [`ReactiveMP.prepare_context`](@ref) to provide extra context that can be reused across multiple [`constrain_form`](@ref) calls. """ struct WrappedFormConstraint{C, X} <: AbstractFormConstraint constraint::C @@ -101,15 +114,16 @@ struct WrappedFormConstraintNoContext end """ prepare_context(constraint) -This function prepares a context for a given form constraint. Returns `WrappedFormConstraintNoContext` if no context is needed (the default behaviour). +Prepares a reusable context for a given form constraint. Returns `WrappedFormConstraintNoContext` by default (i.e. no context needed). +Override this to precompute things that should be shared across multiple [`constrain_form`](@ref) calls. """ prepare_context(constraint) = WrappedFormConstraintNoContext() """ constrain_form(wrapped::WrappedFormConstraint, something) -This function unwraps the `wrapped` object and calls `constrain_form` function with the provided context. -If the context is not provided, simply calls `constrain_form` with the wrapped constraint. Otherwise passes the context to the `constrain_form` function as the second argument. +Unwraps the constraint and delegates to [`constrain_form`](@ref) with the inner constraint. +If a context was provided via [`ReactiveMP.prepare_context`](@ref), it is passed as the second argument. """ constrain_form(wrapped::WrappedFormConstraint, something) = constrain_form( wrapped, wrapped.context, something @@ -131,8 +145,9 @@ default_prod_constraint(wrapped::WrappedFormConstraint) = default_prod_constrain """ preprocess_form_constraints(constraints) -This function preprocesses form constraints and converts the provided objects into a form compatible with ReactiveMP inference backend (if possible). -If a tuple of constraints is passed, it creates a `CompositeFormConstraint` object. Wraps unknown form constraints into a `WrappedFormConstraint` object. +Converts form constraints into a form compatible with the ReactiveMP inference backend. +A tuple of constraints becomes a [`ReactiveMP.CompositeFormConstraint`](@ref). +Objects that are not subtypes of [`AbstractFormConstraint`](@ref) get wrapped into a [`ReactiveMP.WrappedFormConstraint`](@ref). """ function preprocess_form_constraints end @@ -147,7 +162,9 @@ preprocess_form_constraints(constraint) = WrappedFormConstraint( """ CompositeFormConstraint -Creates a composite form constraint that applies form constraints in order. The composed form constraints must be compatible and have the exact same `form_check_strategy`. +A form constraint that chains multiple constraints together, applying them in order via [`constrain_form`](@ref). +Create one by combining constraints with `+` (e.g. `constraint_a + constraint_b`). +All composed constraints must share the same [`default_form_check_strategy`](@ref). """ struct CompositeFormConstraint{C} <: AbstractFormConstraint constraints::C diff --git a/src/message.jl b/src/message.jl index 468815e75..fb627f60c 100644 --- a/src/message.jl +++ b/src/message.jl @@ -113,12 +113,58 @@ function Base.:(==)(left::Message, right::Message) end """ - multiply_messages(prod_strategy, left::Message, right::Message) + MessageProductContext(kwargs...) + +The structure that defines the context for the product of **two** messages within ReactiveMP. +The product is executed with the [`ReactiveMP.compute_product_of_messages`](@ref) function and +uses the `BayesBase.prod` under the hood. See BayesBase product API documentation for detailed description. + +The following `kwargs` are supported: +- `prod_constraint`: defines the first argument for the `BayesBase.prod` function (default is `BayesBase.GenericProd`) +- `form_constraint`: defines the form constraint to be applied on the result of computation, default is [`ReactiveMP.UnspecifiedFormConstraint`](@ref) +- `form_constraint_check_strategy`: defines the strategy to check the specified form constraint, either [`ReactiveMP.FormConstraintCheckLast`](@ref) or [`ReactiveMP.FormConstraintCheckEach`](@ref), default is [`ReactiveMP.FormConstraintCheckLast`](@ref) + + [`ReactiveMP.FormConstraintCheckLast`](@ref) will only call [`ReactiveMP.constrain_form`](@ref) at the end of the `[ReactiveMP.compute_product_of_messages]` + + [`ReactiveMP.FormConstraintCheckEach`](@ref) will call [`ReactiveMP.constrain_form`](@ref) at each of the [`ReactiveMP.compute_product_of_two_messages`](@ref) +- `fold_strategy`: defines the strategy (or simply speaking the direction) of the messages product for [`ReactiveMP.compute_product_of_messages`](@ref), default is [`MessagesProductFromLeftToRight`](@ref). Can be a custom function that accepts a `variable`, `context` and collection of `messages` and does arbitrary order, but still needs to call the [`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood (unless you do some experimental stuff). By the way it is called __fold__ to reflect the computer science term with "left-fold" or "right-fold" (and we use the builtin Julia `foldl` and `foldr` functions for that). +- `callbacks`: callbacks handler, see [`ReactiveMP.invoke_callback`](@ref) for more details. + +See also: [`ReactiveMP.compute_product_of_messages`](@ref), [`ReactiveMP.compute_product_of_two_messages`] +""" +Base.@kwdef struct MessageProductContext{C, F, S, L, A} + prod_constraint::C = BayesBase.GenericProd() + form_constraint::F = UnspecifiedFormConstraint() + form_constraint_check_strategy::S = FormConstraintCheckLast() + fold_strategy::L = MessagesProductFromLeftToRight() + callbacks::A = nothing +end + +""" + compute_product_of_two_messages(variable::AbstractVariable, context::MessageProductContext, left::Message, right::Message) + +Computes the product of two messages `left` and `right` for a given `variable` using the provided `context`. +Returns a new message with the result of the multiplication (not necessarily normalized). +Applies `context.form_constraint` if `context.form_constraint_check_strategy` is set to [`ReactiveMP.FormConstraintCheckEach`](@ref). + +The `variable` argument identifies which variable this product is being computed for, which is useful for callbacks (see [`ReactiveMP.BeforeProductOfTwoMessages`](@ref)). + +## `is_clamped` and `is_initial` + +The [`ReactiveMP.Message`](@ref) carries the `is_clamped` and `is_initial` flags. +The rules for the product are the following: +- If both messages are clamped, the result is clamped, OR +- If both messages are either clamped or initial, the result is initial, OR +- The result is neither clamped nor initial -Multiplies two messages `left` and `right` using a given product strategy `prod_strategy`. -Returns a new message with the result of the multiplication. Note that the resulting message is not necessarily normalized. +See: [`ReactiveMP.MessageProductContext`](@ref), [`ReactiveMP.compute_product_of_messages`](@ref) """ -function multiply_messages(prod_strategy, left::Message, right::Message) +function compute_product_of_two_messages( + variable::AbstractVariable, + context::MessageProductContext, + left::Message, + right::Message, +) + invoke_callback(context.callbacks, BeforeProductOfTwoMessages(), variable, context, left, right) + # We propagate clamped message, in case if both are clamped is_prod_clamped = is_clamped(left) && is_clamped(right) # We propagate initial message, in case if both are initial or left is initial and right is clameped or vice-versa @@ -130,7 +176,13 @@ function multiply_messages(prod_strategy, left::Message, right::Message) # process distributions left_dist = getdata(left) right_dist = getdata(right) - new_dist = prod(prod_strategy, left_dist, right_dist) + new_dist = prod(context.prod_constraint, left_dist, right_dist) + + if context.form_constraint_check_strategy === FormConstraintCheckEach() + invoke_callback(context.callbacks, BeforeFormConstraintApplied(), variable, context, FormConstraintCheckEach(), new_dist) + new_dist = constrain_form(context.form_constraint, new_dist) + invoke_callback(context.callbacks, AfterFormConstraintApplied(), variable, context, FormConstraintCheckEach(), new_dist) + end # process addons left_addons = getaddons(left) @@ -140,59 +192,113 @@ function multiply_messages(prod_strategy, left::Message, right::Message) new_addons = multiply_addons( left_addons, right_addons, new_dist, left_dist, right_dist ) + result = Message(new_dist, is_prod_clamped, is_prod_initial, new_addons) - return Message(new_dist, is_prod_clamped, is_prod_initial, new_addons) + invoke_callback(context.callbacks, AfterProductOfTwoMessages(), variable, context, left, right, result, new_addons) + + return result end -constrain_form_as_message(message::Message, form_constraint) = Message( - constrain_form(form_constraint, getdata(message)), - is_clamped(message), - is_initial(message), - getaddons(message), +# Sometimes we call the product on the `DeferredMessage` that need to be casted to a `Message` +function compute_product_of_two_messages( + variable::AbstractVariable, context::MessageProductContext, left, right ) - -# Note: we need extra Base.Generator(as_message, messages) step here, because some of the messages might be VMP messages -# We want to cast it explicitly to a Message structure (which as_message does in case of DeferredMessage) -# We use with Base.Generator to reduce an amount of memory used by this procedure since Generator generates items lazily -prod_foldl_reduce(prod_constraint, form_constraint, ::FormConstraintCheckEach) = - (messages) -> foldl( - (left, right) -> constrain_form_as_message( - multiply_messages(prod_constraint, left, right), - form_constraint, - ), - Base.Generator(as_message, messages), + return compute_product_of_two_messages( + variable, context, as_message(left), as_message(right) ) +end + +""" + compute_product_of_messages(variable::AbstractVariable, context::MessageProductContext, messages) + +Computes the product of a **collection** of messages for a given `variable` (as opposed to [`ReactiveMP.compute_product_of_two_messages`](@ref), which handles exactly **two** messages). Uses `context.fold_strategy` to determine the order in which [`ReactiveMP.compute_product_of_two_messages`](@ref) is called. By default this is [`ReactiveMP.MessagesProductFromLeftToRight`](@ref), but can be set to an arbitrary function that accepts `variable`, `context` and `messages` and which **must** call [`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood. + +See also: [`ReactiveMP.compute_product_of_two_messages`](@ref), [`ReactiveMP.MessagesProductFromLeftToRight`](@ref) +""" +function compute_product_of_messages( + variable::AbstractVariable, context::MessageProductContext, messages +) + invoke_callback(context.callbacks, BeforeProductOfMessages(), variable, context, messages) -prod_foldl_reduce(prod_constraint, form_constraint, ::FormConstraintCheckLast) = - (messages) -> constrain_form_as_message( - foldl( - (left, right) -> - multiply_messages(prod_constraint, left, right), - Base.Generator(as_message, messages), + result = as_message( + compute_product_of_messages( + context.fold_strategy, variable, context, messages ), - form_constraint, ) -prod_foldr_reduce(prod_constraint, form_constraint, ::FormConstraintCheckEach) = - (messages) -> foldr( - (left, right) -> constrain_form_as_message( - multiply_messages(prod_constraint, left, right), - form_constraint, - ), - Base.Generator(as_message, messages), + if context.form_constraint_check_strategy === FormConstraintCheckLast() + dist = getdata(result) + invoke_callback(context.callbacks, BeforeFormConstraintApplied(), variable, context, FormConstraintCheckLast(), dist) + constrained_dist = constrain_form(context.form_constraint, dist) + invoke_callback(context.callbacks, AfterFormConstraintApplied(), variable, context, FormConstraintCheckLast(), constrained_dist) + result = Message( + constrained_dist, + is_clamped(result), + is_initial(result), + getaddons(result), + ) + end + + invoke_callback(context.callbacks, AfterProductOfMessages(), variable, context, messages, result) + + return result +end + +""" + MessagesProductFromLeftToRight() + +The default fold strategy for [`ReactiveMP.MessageProductContext`](@ref). Computes the product of messages from left to right using `foldl` within [`ReactiveMP.compute_product_of_messages`](@ref). +""" +struct MessagesProductFromLeftToRight end + +function compute_product_of_messages( + ::MessagesProductFromLeftToRight, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return foldl( + (left, right) -> + compute_product_of_two_messages(variable, context, left, right), + messages, ) +end -prod_foldr_reduce(prod_constraint, form_constraint, ::FormConstraintCheckLast) = - (messages) -> constrain_form_as_message( - foldr( - (left, right) -> - multiply_messages(prod_constraint, left, right), - Base.Generator(as_message, messages), - ), - form_constraint, +""" + MessagesProductFromRightToLeft() + +Alternative fold strategy for [`ReactiveMP.MessageProductContext`](@ref). Computes the product of messages from right to left using `foldr` within [`ReactiveMP.compute_product_of_messages`](@ref). +""" +struct MessagesProductFromRightToLeft end + +function compute_product_of_messages( + ::MessagesProductFromRightToLeft, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return foldr( + (left, right) -> + compute_product_of_two_messages(variable, context, left, right), + messages, ) +end + +""" + compute_product_of_messages(f::Function, variable::AbstractVariable, context::MessageProductContext, messages) -# Base.:*(m1::Message, m2::Message) = multiply_messages(m1, m2) +Custom fold strategy for [`ReactiveMP.compute_product_of_messages`](@ref). When `context.fold_strategy` is set to a `Function`, +it will be called with `variable`, `context` and `messages` as arguments. The function must call +[`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood to compute the pairwise products. +""" +function compute_product_of_messages( + f::Function, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return f(variable, context, messages) +end Distributions.pdf(message::Message, x) = Distributions.pdf(getdata(message), x) Distributions.logpdf(message::Message, x) = Distributions.logpdf(getdata(message), x) @@ -246,9 +352,8 @@ mutable struct DeferredMessage{R, S, F} <: AbstractMessage cache :: Union{Nothing, Message} end -DeferredMessage(messages::R, marginals::S, mappingFn::F) where {R, S, F} = DeferredMessage( - messages, marginals, mappingFn, nothing -) +DeferredMessage(messages::R, marginals::S, mappingFn::F) where {R, S, F} = + DeferredMessage(messages, marginals, mappingFn, nothing) function Base.show(io::IO, message::DeferredMessage) cache = getcache(message) @@ -298,17 +403,14 @@ struct MessageObservable{M <: AbstractMessage} <: Subscribable{M} stream :: LazyObservable{M} end -MessageObservable(::Type{M} = AbstractMessage) where {M} = MessageObservable{M}( - RecentSubject(M), lazy(M) -) +MessageObservable(::Type{M} = AbstractMessage) where {M} = + MessageObservable{M}(RecentSubject(M), lazy(M)) -Rocket.getrecent(observable::MessageObservable) = Rocket.getrecent( - observable.subject -) +Rocket.getrecent(observable::MessageObservable) = + Rocket.getrecent(observable.subject) -@inline Rocket.on_subscribe!(observable::MessageObservable, actor) = subscribe!( - observable.stream, actor -) +@inline Rocket.on_subscribe!(observable::MessageObservable, actor) = + subscribe!(observable.stream, actor) @inline Rocket.subscribe!(observable::MessageObservable, actor::Rocket.Actor{<:AbstractMessage}) = Rocket.on_subscribe!(observable.stream, actor) @inline Rocket.subscribe!(observable::MessageObservable, actor::Rocket.NextActor{<:AbstractMessage}) = Rocket.on_subscribe!(observable.stream, actor) @@ -351,7 +453,7 @@ outgoing `Message` from given input messages and marginals using the appropriate See also: [`Message`](@ref), [`DeferredMessage`](@ref) """ -struct MessageMapping{F, T, C, N, M, A, X, R, K} +struct MessageMapping{F, T, C, N, M, A, X, R, K, E} vtag :: T vconstraint :: C msgs_names :: N @@ -360,6 +462,7 @@ struct MessageMapping{F, T, C, N, M, A, X, R, K} addons :: X factornode :: R rulefallback :: K + callbacks :: E end message_mapping_fform(::MessageMapping{F}) where {F} = F @@ -425,8 +528,9 @@ function MessageMapping( addons::X, factornode::R, rulefallback::K, -) where {F, T, C, N, M, A, X, R, K} - return MessageMapping{F, T, C, N, M, A, X, R, K}( + callbacks::E, +) where {F, T, C, N, M, A, X, R, K, E} + return MessageMapping{F, T, C, N, M, A, X, R, K, E}( vtag, vconstraint, msgs_names, @@ -435,6 +539,7 @@ function MessageMapping( addons, factornode, rulefallback, + callbacks, ) end @@ -448,8 +553,9 @@ function MessageMapping( addons::X, factornode::R, rulefallback::K, -) where {F <: Function, T, C, N, M, A, X, R, K} - return MessageMapping{F, T, C, N, M, A, X, R, K}( + callbacks::E, +) where {F <: Function, T, C, N, M, A, X, R, K, E} + return MessageMapping{F, T, C, N, M, A, X, R, K, E}( vtag, vconstraint, msgs_names, @@ -458,6 +564,7 @@ function MessageMapping( addons, factornode, rulefallback, + callbacks, ) end @@ -473,6 +580,13 @@ function (mapping::MessageMapping)(messages, marginals) __check_all(is_clamped_or_initial, marginals) ) + invoke_callback( + mapping.callbacks, + BeforeMessageRuleCallback(), + mapping, + messages, + marginals, + ) result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) @@ -511,6 +625,15 @@ function (mapping::MessageMapping)(messages, marginals) addons = message_mapping_addons( mapping, getdata(messages), getdata(marginals), result, addons ) + invoke_callback( + mapping.callbacks, + AfterMessageRuleCallback(), + mapping, + messages, + marginals, + result, + addons, + ) return Message(result, is_message_clamped, is_message_initial, addons) end diff --git a/src/nodes/dependencies.jl b/src/nodes/dependencies.jl index 10d7ca9c7..8be089b46 100644 --- a/src/nodes/dependencies.jl +++ b/src/nodes/dependencies.jl @@ -19,9 +19,9 @@ function __collect_latest_updates(f::F, collection::Tuple) where {F} (nothing, of(nothing)) else ( - Val{map(name, collection)}(), - combineLatestUpdates(map(f, collection), PushNew()), - ) + Val{map(name, collection)}(), + combineLatestUpdates(map(f, collection), PushNew()), + ) end end @@ -31,6 +31,7 @@ function activate!(dependencies::FunctionalDependencies, factornode, options) scheduler = getscheduler(options) addons = getaddons(options) rulefallback = getrulefallback(options) + callbacks = getcallbacks(options) fform = functionalform(factornode) meta = collect_meta(fform, getmetadata(options)) pipeline = collect_pipeline(fform, getpipeline(options)) @@ -63,6 +64,7 @@ function activate!(dependencies::FunctionalDependencies, factornode, options) addons, node_if_required(fform, factornode), rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 8346cbd23..d215464ba 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -274,13 +274,14 @@ function prepare_interfaces_check_num_inputarguments( ) end -struct FactorNodeActivationOptions{M, D, P, A, S, R} +struct FactorNodeActivationOptions{M, D, P, A, S, R, E} metadata::M dependencies::D pipeline::P addons::A scheduler::S rulefallback::R + callbacks::E end getmetadata(options::FactorNodeActivationOptions) = options.metadata @@ -289,6 +290,7 @@ getpipeline(options::FactorNodeActivationOptions) = options.pipeline getaddons(options::FactorNodeActivationOptions) = options.addons getscheduler(options::FactorNodeActivationOptions) = options.scheduler getrulefallback(options::FactorNodeActivationOptions) = options.rulefallback +getcallbacks(options::FactorNodeActivationOptions) = options.callbacks # Users can override the dependencies if they want to collect_functional_dependencies(fform::F, options::FactorNodeActivationOptions) where {F} = collect_functional_dependencies( diff --git a/src/nodes/predefined/delta/delta.jl b/src/nodes/predefined/delta/delta.jl index b664bcca9..542e08582 100644 --- a/src/nodes/predefined/delta/delta.jl +++ b/src/nodes/predefined/delta/delta.jl @@ -305,6 +305,7 @@ function activate!( scheduler = getscheduler(options) addons = getaddons(options) rulefallback = getrulefallback(options) + callbacks = getcallbacks(options) # First we declare local marginal for `out` edge deltafn_apply_layout( @@ -316,6 +317,7 @@ function activate!( scheduler, addons, rulefallback, + callbacks, ) # Second we declare how to compute a joint marginal over all inbound edges @@ -328,6 +330,7 @@ function activate!( scheduler, addons, rulefallback, + callbacks, ) # Second we declare message passing logic for out interface @@ -340,6 +343,7 @@ function activate!( scheduler, addons, rulefallback, + callbacks, ) # At last we declare message passing logic for input interfaces @@ -352,6 +356,7 @@ function activate!( scheduler, addons, rulefallback, + callbacks, ) end diff --git a/src/nodes/predefined/delta/layouts/cvi.jl b/src/nodes/predefined/delta/layouts/cvi.jl index 36c3f5525..468f2b6cc 100644 --- a/src/nodes/predefined/delta/layouts/cvi.jl +++ b/src/nodes/predefined/delta/layouts/cvi.jl @@ -25,6 +25,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -35,6 +36,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -48,6 +50,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -58,6 +61,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -71,6 +75,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) let interface = factornode.out @@ -101,6 +106,7 @@ function deltafn_apply_layout( addons, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap @@ -128,6 +134,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -138,5 +145,6 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end diff --git a/src/nodes/predefined/delta/layouts/default.jl b/src/nodes/predefined/delta/layouts/default.jl index 315a8e3ee..b94fd3203 100644 --- a/src/nodes/predefined/delta/layouts/default.jl +++ b/src/nodes/predefined/delta/layouts/default.jl @@ -53,6 +53,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) let out = factornode.out, localmarginal = factornode.localmarginals.marginals[1] @@ -71,6 +72,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) let out = factornode.out, ins = factornode.ins, @@ -107,6 +109,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) let out = factornode.out, ins = factornode.ins @@ -137,6 +140,7 @@ function deltafn_apply_layout( addons, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap @@ -164,6 +168,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) # For each outbound message from `in_k` edge we need an inbound message on this edge and a joint marginal over `:ins` edges @@ -193,6 +198,7 @@ function deltafn_apply_layout( addons, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap @@ -236,6 +242,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -246,6 +253,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -258,6 +266,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -268,6 +277,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -280,6 +290,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), @@ -290,6 +301,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) end @@ -303,6 +315,7 @@ function deltafn_apply_layout( scheduler, addons, rulefallback, + callbacks, ) where {F} N = length(factornode.ins) @@ -344,6 +357,7 @@ function deltafn_apply_layout( addons, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap diff --git a/src/rule.jl b/src/rule.jl index e6179ef9e..60296c6da 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1839,8 +1839,7 @@ function get_message_types_from_rule_method(m::Method) "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" else x - end, - tmp3, + end, tmp3 ) tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4) return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) @@ -1877,8 +1876,7 @@ function get_marginal_types_from_rule_method(m::Method) "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" else x - end, - tmp3, + end, tmp3 ) tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4) return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) diff --git a/src/variable.jl b/src/variable.jl new file mode 100644 index 000000000..4d78b0160 --- /dev/null +++ b/src/variable.jl @@ -0,0 +1,10 @@ +""" + AbstractVariable + +An abstract supertype for all variable types in the factor graph. +Concrete subtypes include: +- [`ReactiveMP.RandomVariable`](@ref) +- [`ReactiveMP.ConstVariable`](@ref) +- [`ReactiveMP.DataVariable`](@ref). +""" +abstract type AbstractVariable end diff --git a/src/variables/constant.jl b/src/variables/constant.jl index d52528092..e1760b20e 100644 --- a/src/variables/constant.jl +++ b/src/variables/constant.jl @@ -1,21 +1,36 @@ export constvar, ConstVariable +""" + ConstVariable <: AbstractVariable + +Represents a constant (clamped) variable in the factor graph. The value is fixed at creation time and +wrapped in a `PointMass` distribution. Messages and marginals from this variable are always marked as clamped. +Use [`constvar`](@ref) to create an instance. + +See also: [`ReactiveMP.RandomVariable`](@ref), [`ReactiveMP.DataVariable`](@ref) +""" mutable struct ConstVariable <: AbstractVariable marginal :: MarginalObservable messageout :: MessageObservable constant :: Any nconnected :: Int + label :: Any end -function ConstVariable(constant) +function ConstVariable(constant; label = nothing) marginal = MarginalObservable() connect!(marginal, of(Marginal(PointMass(constant), true, false, nothing))) messageout = MessageObservable(AbstractMessage) connect!(messageout, of(Message(PointMass(constant), true, false, nothing))) - return ConstVariable(marginal, messageout, constant, 0) + return ConstVariable(marginal, messageout, constant, 0, label) end -constvar(constant) = ConstVariable(constant) +""" + constvar(constant; label = nothing) + +Creates a new [`ReactiveMP.ConstVariable`](@ref) with the given `constant` value and an optional `label` for identification. +""" +constvar(constant; label = nothing) = ConstVariable(constant; label = label) degree(constvar::ConstVariable) = constvar.nconnected getconst(constvar::ConstVariable) = constvar.constant diff --git a/src/variables/data.jl b/src/variables/data.jl index fb3cb4656..7f892540b 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,13 +1,22 @@ export datavar, DataVariable, update!, DataVariableActivationOptions +""" + DataVariable <: AbstractVariable + +Represents an observed variable in the factor graph. Unlike [`ReactiveMP.ConstVariable`](@ref), the data is not fixed +at creation time and can be updated later via [`update!`](@ref). Use [`datavar`](@ref) to create an instance. + +See also: [`ReactiveMP.RandomVariable`](@ref), [`ReactiveMP.ConstVariable`](@ref) +""" mutable struct DataVariable{M, P} <: AbstractVariable input_messages :: Vector{MessageObservable{AbstractMessage}} marginal :: MarginalObservable messageout :: M prediction :: P + label :: Any end -function DataVariable() +function DataVariable(; label = nothing) messageout = RecentSubject(Message) marginal = MarginalObservable() prediction = MarginalObservable() @@ -16,10 +25,16 @@ function DataVariable() marginal, messageout, prediction, + label, ) end -datavar() = DataVariable() +""" + datavar(; label = nothing) + +Creates a new [`ReactiveMP.DataVariable`](@ref) with an optional `label` for identification. +""" +datavar(; label = nothing) = DataVariable(; label = label) degree(datavar::DataVariable) = length(datavar.input_messages) @@ -51,9 +66,8 @@ struct DataVariableActivationOptions args end -DataVariableActivationOptions() = DataVariableActivationOptions( - false, false, nothing, nothing -) +DataVariableActivationOptions() = + DataVariableActivationOptions(false, false, nothing, nothing) function activate!( datavar::DataVariable, options::DataVariableActivationOptions @@ -82,13 +96,11 @@ function activate!( return nothing end -__link_getmarginal(constant) = of( - Marginal(PointMass(constant), true, false, nothing) -) +__link_getmarginal(constant) = + of(Marginal(PointMass(constant), true, false, nothing)) __link_getmarginal(l::AbstractVariable) = getmarginal(l, IncludeAll()) -__link_getmarginal(l::AbstractArray{<:AbstractVariable}) = getmarginals( - l, IncludeAll() -) +__link_getmarginal(l::AbstractArray{<:AbstractVariable}) = + getmarginals(l, IncludeAll()) __apply_link(f::F, args) where {F} = __apply_link(f, getdata.(args)) __apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = f(mean.(args)...) @@ -97,6 +109,14 @@ _getmarginal(datavar::DataVariable) = datavar.marginal _setmarginal!(::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`") _makemarginal(::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`") +""" + update!(datavar::DataVariable, data) + update!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray) + +Provides a new observation to a [`ReactiveMP.DataVariable`](@ref) (or an array of data variables). +The `data` is wrapped in a `PointMass` distribution and pushed as a new message. +Pass `missing` to indicate that the observation is not available. +""" update!(datavar::DataVariable, data) = update!(datavar, PointMass(data)) update!(datavar::DataVariable, data::PointMass) = next!(datavar.messageout, Message(data, false, false, nothing)) update!(datavar::DataVariable, ::Missing) = next!(datavar.messageout, Message(missing, false, false, nothing)) @@ -116,13 +136,6 @@ function update!(datavars::AbstractArray{<:DataVariable}, data::Missing) end end -marginal_prod_fn(datavar::DataVariable) = marginal_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) - _getprediction(datavar::DataVariable) = datavar.prediction _setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) -_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar)) +_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, (messages) -> as_marginal(compute_product_of_messages(datavar, MessageProductContext(), messages))) diff --git a/src/variables/variable.jl b/src/variables/generic.jl similarity index 73% rename from src/variables/variable.jl rename to src/variables/generic.jl index 9eedb6cb6..31be52daa 100644 --- a/src/variables/variable.jl +++ b/src/variables/generic.jl @@ -1,46 +1,13 @@ export AbstractVariable, degree -export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy export getprediction, getpredictions, getmarginal, getmarginals export setmarginal!, setmarginals!, setmessage!, setmessages! using Rocket -abstract type AbstractVariable end - ## Base interface extensions Base.broadcastable(v::AbstractVariable) = Ref(v) -## Messages to Marginal product strategies - -struct FoldLeftProdStrategy end -struct FoldRightProdStrategy end - -struct CustomProdStrategy{F} - prod_callback_generator::F -end - -""" - messages_prod_fn(strategy, prod_constraint, form_constraint, form_check_strategy) - -Returns a suitable prod computation function for a given strategy and constraints -""" -function messages_prod_fn end - -messages_prod_fn(::FoldLeftProdStrategy, prod_constraint, form_constraint, form_check_strategy) = prod_foldl_reduce(prod_constraint, form_constraint, form_check_strategy) -messages_prod_fn(::FoldRightProdStrategy, prod_constraint, form_constraint, form_check_strategy) = prod_foldr_reduce(prod_constraint, form_constraint, form_check_strategy) -messages_prod_fn(strategy::CustomProdStrategy, prod_constraint, form_constraint, form_check_strategy) = strategy.prod_callback_generator(prod_constraint, form_constraint, form_check_strategy) - -function marginal_prod_fn( - strategy, prod_constraint, form_constraint, form_check_strategy -) - return let prod_fn = messages_prod_fn( - strategy, prod_constraint, form_constraint, form_check_strategy - ) - return (messages) -> as_marginal(prod_fn(messages)) - end -end - # Helper functions israndom(v::AbstractArray{<:AbstractVariable}) = all(israndom, v) diff --git a/src/variables/random.jl b/src/variables/random.jl index 827260b6a..e95fb9460 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -2,17 +2,32 @@ export randomvar, RandomVariable, RandomVariableActivationOptions ## Random variable implementation +""" + RandomVariable <: AbstractVariable + +Represents a latent (unobserved) variable in the factor graph. Random variables collect incoming and outgoing messages +from connected factor nodes and maintain a marginal belief. Use [`randomvar`](@ref) to create an instance. + +See also: [`ReactiveMP.ConstVariable`](@ref), [`ReactiveMP.DataVariable`](@ref) +""" mutable struct RandomVariable <: AbstractVariable - input_messages::Vector{MessageObservable{AbstractMessage}} - output_messages::Vector{MessageObservable{Message}} - marginal::MarginalObservable + input_messages :: Vector{MessageObservable{AbstractMessage}} + output_messages :: Vector{MessageObservable{Message}} + marginal :: MarginalObservable + label :: Any end -function randomvar() +""" + randomvar(; label = nothing) + +Creates a new [`ReactiveMP.RandomVariable`](@ref) with an optional `label` for identification. +""" +function randomvar(; label = nothing) return RandomVariable( Vector{MessageObservable{AbstractMessage}}(), Vector{MessageObservable{Message}}(), MarginalObservable(), + label, ) end @@ -39,27 +54,16 @@ function messageout(randomvar::RandomVariable, index::Int) return randomvar.output_messages[index] end -const DefaultMessageProdFn = messages_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) -const DefaultMarginalProdFn = marginal_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) - -struct RandomVariableActivationOptions{S, F, M} +struct RandomVariableActivationOptions{ + S, F <: MessageProductContext, M <: MessageProductContext +} scheduler::S - message_prod_fn::F - marginal_prod_fn::M + prod_context_for_message_computation::F + prod_context_for_marginal_computation::M end RandomVariableActivationOptions() = RandomVariableActivationOptions( - AsapScheduler(), DefaultMessageProdFn, DefaultMarginalProdFn + AsapScheduler(), MessageProductContext(), MessageProductContext() ) function activate!( @@ -77,7 +81,11 @@ function activate!( chain = EqualityChain( randomvar.input_messages, schedule_on(options.scheduler), - options.message_prod_fn, + (messages) -> compute_product_of_messages( + randomvar, + options.prod_context_for_message_computation, + messages, + ), ) initialize!(chain, outputmsgs) elseif length(randomvar.input_messages) == 1 @@ -98,9 +106,36 @@ function activate!( end _getmarginal(randomvar::RandomVariable) = randomvar.marginal -_setmarginal!(randomvar::RandomVariable, observable) = connect!( - _getmarginal(randomvar), observable +_setmarginal!(randomvar::RandomVariable, observable) = + connect!(_getmarginal(randomvar), observable) + +function _compute_marginal_from_messages( + randomvar::RandomVariable, + options::RandomVariableActivationOptions, + messages, ) + context = options.prod_context_for_marginal_computation + invoke_callback( + context.callbacks, + BeforeMarginalComputation(), + randomvar, + context, + messages, + ) + result = as_marginal( + compute_product_of_messages(randomvar, context, messages) + ) + invoke_callback( + context.callbacks, + AfterMarginalComputation(), + randomvar, + context, + messages, + result, + ) + return result +end + _makemarginal( randomvar::RandomVariable, options::RandomVariableActivationOptions ) = begin @@ -108,7 +143,8 @@ _makemarginal( AbstractMessage, Marginal, randomvar.input_messages, - options.marginal_prod_fn, + (messages) -> + _compute_marginal_from_messages(randomvar, options, messages), reset_vstatus, ) end diff --git a/test/addons/memory_tests.jl b/test/addons/memory_tests.jl index 51e00a6e6..8a1c8ba17 100644 --- a/test/addons/memory_tests.jl +++ b/test/addons/memory_tests.jl @@ -22,6 +22,7 @@ AddonMemory(), nothing, nothing, + nothing, ) messages = (Gamma(1.0, 1.0), NormalMeanVariance(0.0, 1.0)) diff --git a/test/callbacks_tests.jl b/test/callbacks_tests.jl new file mode 100644 index 000000000..91afb13ba --- /dev/null +++ b/test/callbacks_tests.jl @@ -0,0 +1,259 @@ +@testitem "Callbacks handler should do absolutely nothing if no handler exists" begin + import ReactiveMP: invoke_callback + + args = (1, "Hello", [1, 2, 3], [1;;]) + callback_handler = nothing + + function bar(args) + invoke_callback(callback_handler, Val(:my_event), args) + return nothing + end + + bar(args) + + @test bar(args) === nothing + @test @allocated(bar(args)) === 0 +end + +@testitem "It should be possible to define custom callback handlers" begin + import ReactiveMP: invoke_callback + + struct MyCallbackHandler + events + end + + function ReactiveMP.invoke_callback( + handler::MyCallbackHandler, ::Val{E}, args... + ) where {E} + push!(handler.events, (event = E, args = args)) + return nothing + end + + handler = MyCallbackHandler([]) + + @test invoke_callback(handler, Val{:event1}(), 1, 1) === nothing + @test invoke_callback(handler, Val{:event2}(), 2, 3) === nothing + + @test length(handler.events) === 2 + @test handler.events[1].event === :event1 + @test handler.events[1].args === (1, 1) + @test handler.events[2].event === :event2 + @test handler.events[2].args === (2, 3) + + @test_throws MethodError invoke_callback( + handler, "unsupported type of event", 1, 2 + ) +end + +@testitem "invoke_callback error hint for forgotten trailing comma in NamedTuple" begin + import ReactiveMP: invoke_callback + + # This simulates the common mistake: `(before_product_of_messages = fn)` without trailing comma. + # Julia parses this as a plain assignment, so `callbacks` becomes just `fn` (a Function). + callbacks = (before_product_of_messages = (args...) -> nothing) + + # Verify that Julia indeed parsed this as a Function, not a NamedTuple + @test callbacks isa Function + @test !(callbacks isa NamedTuple) + + err = try + invoke_callback(callbacks, Val(:before_product_of_messages), 1, 2) + catch e + e + end + @test err isa MethodError + + # Check that the error hint mentions both possible causes + hint_message = sprint(showerror, err) + @test occursin("invoke_callback", hint_message) + @test occursin("trailing comma", hint_message) +end + +@testitem "invoke_callback error hint for custom handler with missing method" begin + import ReactiveMP: invoke_callback + + # Custom handler that only implements invoke_callback for :event1 but not :event2 + struct IncompleteHandler end + + ReactiveMP.invoke_callback(::IncompleteHandler, ::Val{:event1}, args...) = + nothing + + handler = IncompleteHandler() + + # :event1 works fine + @test invoke_callback(handler, Val(:event1), 1, 2) === nothing + + # :event2 is not implemented — should hit MethodError with a helpful hint + err = try + invoke_callback(handler, Val(:event2), 1, 2) + catch e + e + end + @test err isa MethodError + + hint_message = sprint(showerror, err) + @test occursin( + r"ReactiveMP\.invoke_callback\(::.*IncompleteHandler, ::Val\{:event2\}, args\.\.\.\) = \.\.\.", + hint_message, + ) + @test occursin( + "You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma.", + hint_message, + ) +end + +@testitem "NamedTuple should be a supported event handler" begin + import ReactiveMP: invoke_callback + + callback_handler = ( + sum_event = (args...) -> sum(args), prod_event = (args...) -> prod(args) + ) + + @test @inferred( + invoke_callback(callback_handler, Val{:sum_event}(), 1, 2) + ) == 3 + @test @inferred( + invoke_callback(callback_handler, Val{:sum_event}(), 1, 2, 3) + ) == 6 + @test @inferred( + invoke_callback(callback_handler, Val{:prod_event}(), 1, 2) + ) == 2 + @test @inferred( + invoke_callback(callback_handler, Val{:prod_event}(), 1, 2, 5) + ) == 10 + @test @inferred( + invoke_callback(callback_handler, Val{:other_event}(), 1, 2, 3) + ) === nothing +end + +@testitem "It should be possible to merge callback handlers" begin + import ReactiveMP: invoke_callback, merge_callbacks + + # listens to event 1 and event 2 + handler1_events = [] + callback_handler1 = ( + event1 = (args...) -> push!(handler1_events, :event1), + event2 = (args...) -> push!(handler1_events, :event2), + ) + + # listens to event3 and event 2 + handler2_events = [] + callback_handler2 = ( + event3 = (args...) -> push!(handler2_events, :event3), + event2 = (args...) -> push!(handler2_events, :event2), + ) + + # only listens to event 2 + struct MyCustomHandler + events + end + + ReactiveMP.invoke_callback(::MyCustomHandler, event, args...) = nothing + ReactiveMP.invoke_callback( + handler::MyCustomHandler, event::Val{:event2}, args... + ) = push!(handler.events, :event2) + + custom_handler = MyCustomHandler([]) + + merged_handler = merge_callbacks( + callback_handler1, callback_handler2, custom_handler + ) + + for i in 1:5 + invoke_callback(merged_handler, Val(:event1), 1, 1) + invoke_callback(merged_handler, Val(:event2), "hello") + invoke_callback(merged_handler, Val(:event3), 3.0) + end + + @test length(handler1_events) == 10 + @test Set(handler1_events) == Set([:event1, :event2]) + @test length(handler2_events) == 10 + @test Set(handler2_events) == Set([:event3, :event2]) + @test length(custom_handler.events) == 5 + @test Set(custom_handler.events) == Set([:event2]) +end + +@testitem "It should be possible to reduce the result of the merged callback handlers" begin + import ReactiveMP: invoke_callback, merge_callbacks + + callback_handler1 = (event1 = (a, b) -> a + b,) + callback_handler2 = (event1 = (a, b) -> a * b,) + + merged_handler1 = merge_callbacks(callback_handler1, callback_handler2) + + @test @inferred(invoke_callback(merged_handler1, Val(:event1), 2, 3)) === + (5, 6) + + merged_handler2 = merge_callbacks( + callback_handler1, callback_handler2; reduce_fn = + + ) + + @test @inferred(invoke_callback(merged_handler2, Val(:event1), 4, 5)) === 29 + + merged_handler3 = merge_callbacks( + callback_handler1, callback_handler2; reduce_fn = * + ) + + @test @inferred( + invoke_callback(merged_handler3, Val(:event1), 1.0, 2.0) + ) === 6.0 +end + +@testitem "It should be possible to use different reduce functions for different events" begin + import ReactiveMP: invoke_callback, merge_callbacks + + callback_handler1 = (event1 = (a, b) -> a + b, event2 = (a, b) -> a - b) + callback_handler2 = (event1 = (a, b) -> a * b, event2 = (a, b) -> a / b) + + merged_handler1 = merge_callbacks(callback_handler1, callback_handler2) + + @test @inferred(invoke_callback(merged_handler1, Val(:event1), 2, 3)) === + (5, 6) + @test @inferred(invoke_callback(merged_handler1, Val(:event2), 3, 4)) === + (-1, 3 / 4) + + merged_handler2 = merge_callbacks( + callback_handler1, + callback_handler2; + reduce_fn = (event1 = +, event2 = *), + ) + + @test @inferred(invoke_callback(merged_handler2, Val(:event1), 4, 5)) === 29 + @test @inferred(invoke_callback(merged_handler2, Val(:event2), 4, 5)) === + -4 / 5 + + merged_handler3 = merge_callbacks( + callback_handler1, + callback_handler2; + reduce_fn = (event1 = *, event2 = +), + ) + + @test @inferred( + invoke_callback(merged_handler3, Val(:event1), 1.0, 2.0) + ) === 6.0 + @test @inferred( + invoke_callback(merged_handler3, Val(:event2), 1.0, 2.0) + ) === -1.0 + 1.0 / 2.0 + + merged_handler4 = merge_callbacks( + callback_handler1, callback_handler2; reduce_fn = (event1 = -,) + ) + + @test @inferred( + invoke_callback(merged_handler4, Val(:event1), 1.0, 2.0) + ) === 1.0 + @test @inferred( + invoke_callback(merged_handler4, Val(:event2), 1.0, 2.0) + ) === (-1.0, 1.0 / 2.0) + + merged_handler5 = merge_callbacks( + callback_handler1, callback_handler2; reduce_fn = (event2 = /,) + ) + + @test @inferred( + invoke_callback(merged_handler5, Val(:event1), 1.0, 2.0) + ) === (3.0, 2.0) + @test @inferred( + invoke_callback(merged_handler5, Val(:event2), 1.0, 2.0) + ) === -1.0 / (1.0 / 2.0) +end diff --git a/test/message_tests.jl b/test/message_tests.jl index 81537d3cb..2fec6daec 100644 --- a/test/message_tests.jl +++ b/test/message_tests.jl @@ -5,12 +5,17 @@ import Base: methods import Base.Iterators: repeated, product import BayesBase: xtlog, mirrorlog - import ReactiveMP: getaddons, multiply_messages, as_message + import ReactiveMP: + getaddons, + compute_product_of_two_messages, + MessageProductContext, + as_message import SpecialFunctions: loggamma @testset "Default methods" begin for clamped in (true, false), - initial in (true, false), addons in (1, 2), + initial in (true, false), + addons in (1, 2), data in (1, 1.0, Normal(0, 1), Gamma(1, 1), PointMass(1)) msg = Message(data, clamped, initial, addons) @@ -28,7 +33,8 @@ dist2 = MvNormalMeanCovariance([0.0, 1.0], [1.0 0.0; 0.0 1.0]) for clamped1 in (true, false), - clamped2 in (true, false), initial1 in (true, false), + clamped2 in (true, false), + initial1 in (true, false), initial2 in (true, false) msg1 = Message(dist1, clamped1, initial1, nothing) @@ -40,8 +46,12 @@ end end - @testset "multiply_messages" begin - × = (x, y) -> multiply_messages(GenericProd(), x, y) + @testset "compute product of two messages" begin + _testvar = ReactiveMP.randomvar() + × = + (x, y) -> compute_product_of_two_messages( + _testvar, MessageProductContext(), x, y + ) dist1 = NormalMeanVariance(randn(), rand()) dist2 = NormalMeanVariance(randn(), rand()) @@ -218,9 +228,8 @@ end end - _getpoint(rng, distribution) = _getpoint( - rng, variate_form(typeof(distribution)), distribution - ) + _getpoint(rng, distribution) = + _getpoint(rng, variate_form(typeof(distribution)), distribution) _getpoint(rng, ::Type{<:Univariate}, distribution) = 10rand(rng) _getpoint(rng, ::Type{<:Multivariate}, distribution) = 10 .* rand(rng, 2) @@ -307,6 +316,7 @@ end addons, SomeArbitraryNode(), nothing, + nothing, ) messages = (Message(NonexistingDistribution(), false, false, nothing),) @@ -328,6 +338,7 @@ end addons, SomeArbitraryNode(), rulefallback, + nothing, ) @test getdata(mapping_with_fallback(messages, marginals)) == ( @@ -343,3 +354,562 @@ end SomeArbitraryNode(), ) end + +@testitem "MessageMapping should call provided callbacks handler" begin + import ReactiveMP: MessageMapping, getdata + + struct SomeArbitraryNode end + + @node SomeArbitraryNode Stochastic [out, in] + + @rule SomeArbitraryNode(:out, Marginalisation) (m_in::Int,) = m_in + 1 + + events = [] + + callbacks = ( + before_message_rule_call = (args...) -> + push!(events, (event = :before_message_rule_call, args = args)), + after_message_rule_call = (args...) -> + push!(events, (event = :after_message_rule_call, args = args)), + ) + + mapping = MessageMapping( + SomeArbitraryNode, + Val(:out), + Marginalisation(), + Val((:in,)), + nothing, + nothing, + (), + SomeArbitraryNode(), + nothing, + callbacks, + ) + + messages = (Message(1, false, false, nothing),) + marginals = nothing + + @test getdata(mapping(messages, marginals)) == 2 + + @test events[1].event == :before_message_rule_call + @test events[1].args[1].factornode === SomeArbitraryNode() + @test events[1].args[2] === messages + @test events[1].args[3] === marginals + + @test events[2].event == :after_message_rule_call + @test events[2].args[1].factornode === SomeArbitraryNode() + @test events[2].args[2] === messages + @test events[2].args[3] === marginals + @test events[2].args[4] === 2 + @test events[2].args[5] === () +end + +@testmodule MessageProductContextUtils begin + import ReactiveMP: AbstractVariable, AbstractFormConstraint, constrain_form + import ReactiveMP.BayesBase: prod, GenericProd, isapprox + import ReactiveMP + + struct Normal + mean::Float64 + var::Float64 + end + + function prod(::GenericProd, left::Normal, right::Normal) + result_var = 1 / (1 / left.var + 1 / right.var) + result_mean = + result_var * (left.mean / left.var + right.mean / right.var) + return Normal(result_mean, result_var) + end + + function isapprox(left::Normal, right::Normal; kwargs...) + return isapprox(left.mean, right.mean; kwargs...) && + isapprox(left.var, right.var; kwargs...) + end + + struct AbstractVariableForMessageProductContextTests <: AbstractVariable end + + testvar = AbstractVariableForMessageProductContextTests() + + # A simple form constraint that adds +1 to the mean of a Normal distribution + struct AddOneToMeanConstraint <: AbstractFormConstraint end + + constrain_form(::AddOneToMeanConstraint, dist::Normal) = + Normal(dist.mean + 1, dist.var) + + # A callback handler that only records events from a specified set + struct SaveOrderOfComputationCallbacks + listen_to::Tuple + events + end + + function ReactiveMP.invoke_callback( + handler::SaveOrderOfComputationCallbacks, ::Val{E}, args... + ) where {E} + E ∈ handler.listen_to && push!(handler.events, (event = E, args = args)) + end + + export Normal, + testvar, AddOneToMeanConstraint, SaveOrderOfComputationCallbacks +end + +@testitem "MessageProductContext should compute product of two messages" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + Message, MessageProductContext, compute_product_of_two_messages, getdata + + context = MessageProductContext() + + msg1 = Message(Normal(0, 1), false, false, nothing) + msg2 = Message(Normal(0, 1), false, false, nothing) + + result = @inferred( + compute_product_of_two_messages(testvar, context, msg1, msg2) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / 2) +end + +@testitem "compute_message_product propagates the `is_clamped` and `is_initial` correctly" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + Message, + MessageProductContext, + compute_product_of_two_messages, + is_clamped, + is_initial + + context = MessageProductContext() + + for left_is_clamped in (true, false), + right_is_clamped in (true, false), + left_is_initial in (true, false), + right_is_initial in (true, false) + + msg1 = Message(Normal(0, 1), left_is_clamped, left_is_initial, nothing) + msg2 = Message( + Normal(0, 1), right_is_clamped, right_is_initial, nothing + ) + + result = @inferred( + compute_product_of_two_messages(testvar, context, msg1, msg2) + ) + + @test result isa Message + + expected_result_is_clamped = left_is_clamped && right_is_clamped + expected_result_is_initial = + !expected_result_is_clamped && + (left_is_clamped || left_is_initial) && + (right_is_clamped || right_is_initial) + + @test is_clamped(result) === expected_result_is_clamped + @test is_initial(result) === expected_result_is_initial + end +end + +@testitem "compute_message_product should support different folding strategies" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + compute_product_of_messages, + compute_product_of_two_messages, + getdata + + messages = [ + Message(Normal(0, 1), false, false, nothing) + Message(Normal(0, 2), false, false, nothing) + Message(Normal(0, 3), false, false, nothing) + ] + + @testset "From left to right" begin + import ReactiveMP: MessagesProductFromLeftToRight + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromLeftToRight(), + callbacks = handler, + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + # 3 messages = 2 products, each product fires a before and after callback + @test length(handler.events) == 4 + @test handler.events[1].event === :before_product_of_two_messages + @test handler.events[2].event === :after_product_of_two_messages + @test handler.events[3].event === :before_product_of_two_messages + @test handler.events[4].event === :after_product_of_two_messages + + # First product: Normal(0,1) × Normal(0,2) — left to right order + @test getdata(handler.events[1].args[3]) == Normal(0, 1) + @test getdata(handler.events[1].args[4]) == Normal(0, 2) + + # Second product: result of first × Normal(0,3) + @test getdata(handler.events[3].args[3]) == + Normal(0, 1 / (1 / 1 + 1 / 2)) + @test getdata(handler.events[3].args[4]) == Normal(0, 3) + end + + @testset "From right to left" begin + import ReactiveMP: MessagesProductFromRightToLeft + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromRightToLeft(), + callbacks = handler, + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + # 3 messages = 2 products, each product fires a before and after callback + @test length(handler.events) == 4 + @test handler.events[1].event === :before_product_of_two_messages + @test handler.events[2].event === :after_product_of_two_messages + @test handler.events[3].event === :before_product_of_two_messages + @test handler.events[4].event === :after_product_of_two_messages + + # First product: Normal(0,2) × Normal(0,3) — right to left order + @test getdata(handler.events[1].args[3]) == Normal(0, 2) + @test getdata(handler.events[1].args[4]) == Normal(0, 3) + + # Second product: Normal(0,1) × result of first + @test getdata(handler.events[3].args[3]) == Normal(0, 1) + @test getdata(handler.events[3].args[4]) == + Normal(0, 1 / (1 / 2 + 1 / 3)) + end + + @testset "Custom fold strategy via Function" begin + # Custom strategy: compute (1 × 3) × 2 + custom_fold = + (variable, context, messages) -> begin + first = compute_product_of_two_messages( + variable, context, messages[1], messages[3] + ) + return compute_product_of_two_messages( + variable, context, first, messages[2] + ) + end + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = custom_fold, callbacks = handler + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + @test length(handler.events) == 4 + + # First product: Normal(0,1) × Normal(0,3) + @test getdata(handler.events[1].args[3]) == Normal(0, 1) + @test getdata(handler.events[1].args[4]) == Normal(0, 3) + + # Second product: result of first × Normal(0,2) + @test getdata(handler.events[3].args[3]) == + Normal(0, 1 / (1 / 1 + 1 / 3)) + @test getdata(handler.events[3].args[4]) == Normal(0, 2) + end + + @testset "Before and after callbacks receive correct arguments" begin + import ReactiveMP: MessagesProductFromLeftToRight + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromLeftToRight(), + callbacks = handler, + ) + + msg1 = Message(Normal(0, 1), false, false, nothing) + msg2 = Message(Normal(0, 2), false, false, nothing) + + result = @inferred( + compute_product_of_messages(testvar, context, [msg1, msg2]) + ) + + @test length(handler.events) == 2 + + # Before callback: variable, context, left, right + before = handler.events[1] + @test before.event === :before_product_of_two_messages + @test before.args[1] === testvar + @test before.args[2] === context + @test getdata(before.args[3]) == Normal(0, 1) + @test getdata(before.args[4]) == Normal(0, 2) + + # After callback: variable, context, left, right, result, addons + after = handler.events[2] + @test after.event === :after_product_of_two_messages + @test after.args[1] === testvar + @test after.args[2] === context + @test getdata(after.args[3]) == Normal(0, 1) + @test getdata(after.args[4]) == Normal(0, 2) + @test after.args[5] == result + @test after.args[6] == nothing # no addons + end +end + +@testitem "Form constraint callbacks with FormConstraintCheckEach" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + FormConstraintCheckEach, + compute_product_of_messages, + compute_product_of_two_messages, + getdata + + messages = [ + Message(Normal(0, 1), false, false, nothing) + Message(Normal(0, 2), false, false, nothing) + Message(Normal(0, 3), false, false, nothing) + ] + + @testset "CheckEach applies form constraint after each pairwise product" begin + listen_to = ( + :before_form_constraint_applied, :after_form_constraint_applied + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckEach(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # Hand-computed expected values (left-to-right fold with +1 to mean after each product): + # Step 1: prod(Normal(0,1), Normal(0,2)) + # var = 1/(1 + 1/2) = 2/3, mean = (2/3)*(0 + 0) = 0 => Normal(0, 2/3) + # constraint: Normal(0 + 1, 2/3) = Normal(1, 2/3) + # Step 2: prod(Normal(1, 2/3), Normal(0,3)) + # var = 1/(3/2 + 1/3) = 6/11, mean = (6/11)*(1*3/2 + 0*1/3) = 9/11 => Normal(9/11, 6/11) + # constraint: Normal(9/11 + 1, 6/11) = Normal(20/11, 6/11) + @test getdata(result) ≈ Normal(20 / 11, 6 / 11) + + # With CheckEach and 3 messages: 2 pairwise products => 2 before + 2 after = 4 events + @test length(handler.events) == 4 + @test handler.events[1].event === :before_form_constraint_applied + @test handler.events[2].event === :after_form_constraint_applied + @test handler.events[3].event === :before_form_constraint_applied + @test handler.events[4].event === :after_form_constraint_applied + + # All form constraint events should carry the CheckEach strategy + for e in handler.events + @test e.args[3] === FormConstraintCheckEach() + end + + # First constraint: before gets Normal(0, 2/3), after gets Normal(1, 2/3) + @test handler.events[1].args[4] ≈ Normal(0, 2 / 3) + @test handler.events[2].args[4] ≈ Normal(1, 2 / 3) + + # Second constraint: before gets Normal(9/11, 6/11), after gets Normal(20/11, 6/11) + @test handler.events[3].args[4] ≈ Normal(9 / 11, 6 / 11) + @test handler.events[4].args[4] ≈ Normal(20 / 11, 6 / 11) + end +end + +@testitem "Form constraint callbacks with FormConstraintCheckLast" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + FormConstraintCheckLast, + compute_product_of_messages, + getdata + + messages = [ + Message(Normal(0, 1), false, false, nothing) + Message(Normal(0, 2), false, false, nothing) + Message(Normal(0, 3), false, false, nothing) + ] + + @testset "CheckLast applies form constraint once at the end" begin + listen_to = ( + :before_form_constraint_applied, :after_form_constraint_applied + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckLast(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # Hand-computed expected values (left-to-right fold, constraint only at the end): + # Step 1: prod(Normal(0,1), Normal(0,2)) + # var = 2/3, mean = 0 => Normal(0, 2/3) — no constraint + # Step 2: prod(Normal(0, 2/3), Normal(0,3)) + # var = 1/(3/2 + 1/3) = 6/11, mean = 0 => Normal(0, 6/11) — no constraint + # Final constraint: Normal(0 + 1, 6/11) = Normal(1, 6/11) + @test getdata(result) ≈ Normal(1, 6 / 11) + + # With CheckLast, form constraint fires only once (1 before + 1 after) + @test length(handler.events) == 2 + @test handler.events[1].event === :before_form_constraint_applied + @test handler.events[2].event === :after_form_constraint_applied + + # The strategy should be CheckLast + @test handler.events[1].args[3] === FormConstraintCheckLast() + @test handler.events[2].args[3] === FormConstraintCheckLast() + + # Before constraint: Normal(0, 6/11), after constraint: Normal(1, 6/11) + @test handler.events[1].args[4] ≈ Normal(0, 6 / 11) + @test handler.events[2].args[4] ≈ Normal(1, 6 / 11) + + # The final result matches the after-constraint distribution + @test getdata(result) ≈ handler.events[2].args[4] + end + + @testset "Before/after product of messages callbacks fire around the whole computation" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckLast(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # BeforeProductOfMessages should be the first event + @test length(handler.events) == 2 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[1].args[1] === testvar + @test handler.events[1].args[2] === context + @test handler.events[1].args[3] === messages + + # AfterProductOfMessages should be the last event + @test handler.events[2].event === :after_product_of_messages + @test handler.events[2].args[1] === testvar + @test handler.events[2].args[2] === context + @test handler.events[2].args[3] === messages + @test handler.events[2].args[4] == result + end +end + +@testitem "Before/after product of messages callbacks" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, Message, compute_product_of_messages, getdata + + @testset "Fires with default context (no form constraint)" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(0, 1), false, false, nothing) + Message(Normal(0, 2), false, false, nothing) + Message(Normal(0, 3), false, false, nothing) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # prod(Normal(0,1), Normal(0,2)) = Normal(0, 2/3) + # prod(Normal(0, 2/3), Normal(0,3)) = Normal(0, 6/11) + @test getdata(result) ≈ Normal(0, 6 / 11) + + @test length(handler.events) == 2 + + # Before: receives variable, context, and the original messages + @test handler.events[1].event === :before_product_of_messages + @test handler.events[1].args[1] === testvar + @test handler.events[1].args[2] === context + @test handler.events[1].args[3] === messages + + # After: receives variable, context, original messages, and the final result + @test handler.events[2].event === :after_product_of_messages + @test handler.events[2].args[1] === testvar + @test handler.events[2].args[2] === context + @test handler.events[2].args[3] === messages + @test getdata(handler.events[2].args[4]) ≈ getdata(result) + end + + @testset "Fires with two messages" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(1, 4), false, false, nothing) + Message(Normal(3, 4), false, false, nothing) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # prod(Normal(1,4), Normal(3,4)): + # var = 1/(1/4 + 1/4) = 2, mean = 2*(1/4 + 3/4) = 2 + @test getdata(result) ≈ Normal(2, 2) + + @test length(handler.events) == 2 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[2].event === :after_product_of_messages + @test getdata(handler.events[2].args[4]) ≈ getdata(result) + end + + @testset "Before fires before any pairwise product, after fires after all" begin + # Listen to all product-related events to verify ordering + listen_to = ( + :before_product_of_messages, + :before_product_of_two_messages, + :after_product_of_two_messages, + :after_product_of_messages, + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(0, 1), false, false, nothing) + Message(Normal(0, 2), false, false, nothing) + Message(Normal(0, 3), false, false, nothing) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # 3 messages => 1 before_all + 2*(before_two + after_two) + 1 after_all = 6 events + @test length(handler.events) == 6 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[2].event === :before_product_of_two_messages + @test handler.events[3].event === :after_product_of_two_messages + @test handler.events[4].event === :before_product_of_two_messages + @test handler.events[5].event === :after_product_of_two_messages + @test handler.events[6].event === :after_product_of_messages + + # The after_product_of_messages result should match the final result + @test getdata(handler.events[6].args[4]) ≈ getdata(result) + end +end diff --git a/test/nodes/clusters_tests.jl b/test/nodes/clusters_tests.jl index 461fcd081..3e4c8a381 100644 --- a/test/nodes/clusters_tests.jl +++ b/test/nodes/clusters_tests.jl @@ -241,7 +241,7 @@ end ) options = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, nothing, nothing + nothing, nothing, nothing, nothing, nothing, nothing, nothing ) @test length(getmarginals(getlocalclusters(node))) === 1 @@ -269,7 +269,7 @@ end ) options = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, nothing, nothing + nothing, nothing, nothing, nothing, nothing, nothing, nothing ) @test length(getmarginals(getlocalclusters(node))) === 2 @@ -299,7 +299,7 @@ end ) options = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, nothing, nothing + nothing, nothing, nothing, nothing, nothing, nothing, nothing ) @test length(getmarginals(getlocalclusters(node))) === 2 @@ -329,7 +329,7 @@ end ) options = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, nothing, nothing + nothing, nothing, nothing, nothing, nothing, nothing, nothing ) @test length(getmarginals(getlocalclusters(node))) === 2 diff --git a/test/nodes/dependencies_tests.jl b/test/nodes/dependencies_tests.jl index 3f8c4232f..13aeef918 100644 --- a/test/nodes/dependencies_tests.jl +++ b/test/nodes/dependencies_tests.jl @@ -640,7 +640,7 @@ end @testset "use_a metadata results in CustomDependencyA" begin options_a = FactorNodeActivationOptions( - :use_a, nothing, nothing, nothing, AsapScheduler(), nothing + :use_a, nothing, nothing, nothing, AsapScheduler(), nothing, nothing ) deps = collect_functional_dependencies(CustomMetaNode, options_a) @test deps isa CustomDependencyA @@ -648,7 +648,7 @@ end @testset "use_b metadata results in CustomDependencyB" begin options_b = FactorNodeActivationOptions( - :use_b, nothing, nothing, nothing, AsapScheduler(), nothing + :use_b, nothing, nothing, nothing, AsapScheduler(), nothing, nothing ) deps = collect_functional_dependencies(CustomMetaNode, options_b) @test deps isa CustomDependencyB @@ -656,7 +656,13 @@ end @testset "no metadata falls back to default dependencies" begin options_default = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, AsapScheduler(), nothing + nothing, + nothing, + nothing, + nothing, + AsapScheduler(), + nothing, + nothing, ) deps = collect_functional_dependencies(CustomMetaNode, options_default) @test deps isa DefaultFunctionalDependencies @@ -670,7 +676,7 @@ end ((1,),), ) options_a = FactorNodeActivationOptions( - :use_a, nothing, nothing, nothing, AsapScheduler(), nothing + :use_a, nothing, nothing, nothing, AsapScheduler(), nothing, nothing ) deps_a = collect_functional_dependencies(CustomMetaNode, options_a) activate!(node_a, options_a) @@ -689,7 +695,7 @@ end ((1,),), ) options_b = FactorNodeActivationOptions( - :use_b, nothing, nothing, nothing, AsapScheduler(), nothing + :use_b, nothing, nothing, nothing, AsapScheduler(), nothing, nothing ) deps_b = collect_functional_dependencies(CustomMetaNode, options_b) activate!(node_b, options_b) diff --git a/test/nodes/predefined/mixture_tests.jl b/test/nodes/predefined/mixture_tests.jl index d1deaa736..97967b2f2 100644 --- a/test/nodes/predefined/mixture_tests.jl +++ b/test/nodes/predefined/mixture_tests.jl @@ -21,7 +21,6 @@ interfaces, sdtype, factornode, - FactorNodeActivationOptions, activate! # Common interfaces and factorizations used by both test groups diff --git a/test/variables/random_tests.jl b/test/variables/random_tests.jl index b298d62fe..389744cef 100644 --- a/test/variables/random_tests.jl +++ b/test/variables/random_tests.jl @@ -31,7 +31,9 @@ end @testitem "RandomVariable: getmarginal" begin import ReactiveMP: MessageObservable, + MessageProductContext, create_messagein!, + compute_product_of_messages, messagein, degree, activate!, @@ -41,8 +43,8 @@ end include("../testutilities.jl") - message_prod_fn = (msgs) -> error("Messages should not be called here") - marginal_prod_fn = (msgs) -> mgl(sum(getdata.(msgs))) + message_prod_fold = (variable, context, msgs) -> error("Messages should not be called here") + marginal_prod_fold = (variable, context, msgs) -> msg(sum(getdata.(msgs))) for d in 1:5:100 let var = randomvar() messageins = map(1:d) do _ @@ -55,20 +57,22 @@ end activate!( var, RandomVariableActivationOptions( - AsapScheduler(), message_prod_fn, marginal_prod_fn + AsapScheduler(), + MessageProductContext(fold_strategy = message_prod_fold), + MessageProductContext(fold_strategy = marginal_prod_fold), ), ) messages = map(msg, rand(d)) - marginal_expected = marginal_prod_fn(messages) + marginal_expected = mgl(sum(getdata.(messages))) marginal_result = check_stream_updated_once(getmarginal(var)) do foreach(zip(messageins, messages)) do (messagein, message) next!(messagein, message) end end - # We check the `getdata` here approximatelly because the `marginal_prod_fn` can rearrange + # We check the `getdata` here approximatelly because the `marginal_prod_fn` can rearrange # the messages under the hood that introduces minor numerical differences @test getdata(marginal_result) ≈ getdata(marginal_expected) end @@ -78,7 +82,9 @@ end @testitem "RandomVariable: messageout" begin import ReactiveMP: MessageObservable, + MessageProductContext, create_messagein!, + compute_product_of_messages, messagein, degree, activate!, @@ -88,8 +94,8 @@ end include("../testutilities.jl") - message_prod_fn = (msgs) -> msg(sum(filter(!ismissing, getdata.(msgs)))) - marginal_prod_fn = (msgs) -> error("Marginal should not be called here") + message_prod_fold = (variable, context, msgs) -> msg(sum(filter(!ismissing, getdata.(msgs)))) + marginal_prod_fold = (variable, context, msgs) -> error("Marginal should not be called here") # We start from `2` because `1` is not a valid degree for a random variable for d in 2:5:100, k in 1:d @@ -104,26 +110,155 @@ end activate!( var, RandomVariableActivationOptions( - AsapScheduler(), message_prod_fn, marginal_prod_fn + AsapScheduler(), + MessageProductContext(fold_strategy = message_prod_fold), + MessageProductContext(fold_strategy = marginal_prod_fold), ), ) messages = map(msg, rand(d)) # the outbound message is the result of multiplication of `n - 1` messages excluding index `k` - kmessage_expected = message_prod_fn(collect(skipindex(messages, k))) + kmessage_expected = msg(sum(filter(!ismissing, getdata.(collect(skipindex(messages, k)))))) kmessage_result = check_stream_updated_once(messageout(var, k)) do foreach(zip(messageins, messages)) do (messagein, message) next!(messagein, message) end end - # We check the `getdata` here approximatelly because the `message_prod_fn` can rearrange + # We check the `getdata` here approximatelly because the `message_prod_fn` can rearrange # the messages under the hood that introduces minor numerical differences @test getdata(kmessage_result) ≈ getdata(kmessage_expected) end end end +@testitem "RandomVariable: before/after marginal computation callbacks" begin + import ReactiveMP: + MessageObservable, + MessageProductContext, + RandomVariableActivationOptions, + AbstractMessage, + create_messagein!, + activate!, + connect!, + getdata + + import Rocket: Subject, next! + + include("../testutilities.jl") + + struct MarginalCallbackHandler + listen_to::Tuple + events + end + + function ReactiveMP.invoke_callback( + handler::MarginalCallbackHandler, ::Val{E}, args... + ) where {E} + E ∈ handler.listen_to && push!(handler.events, (event = E, args = args)) + end + + @testset "Fires before and after marginal computation with 3 messages" begin + listen_to = (:before_marginal_computation, :after_marginal_computation) + handler = MarginalCallbackHandler(listen_to, []) + marginal_context = MessageProductContext( + fold_strategy = (variable, context, msgs) -> msg(sum(getdata.(msgs))), + callbacks = handler, + ) + + var = randomvar() + + messageins = map(1:3) do _ + s = Subject(AbstractMessage) + m, i = create_messagein!(var) + connect!(m, s) + return s + end + + activate!( + var, + RandomVariableActivationOptions( + AsapScheduler(), + MessageProductContext(), + marginal_context, + ), + ) + + messages = [msg(1.0), msg(2.0), msg(3.0)] + + marginal_result = check_stream_updated_once(getmarginal(var)) do + foreach(zip(messageins, messages)) do (messagein, message) + next!(messagein, message) + end + end + + # sum(1.0 + 2.0 + 3.0) = 6.0 + @test getdata(marginal_result) ≈ 6.0 + + @test length(handler.events) == 2 + + # Before: variable, context, messages + @test handler.events[1].event === :before_marginal_computation + @test handler.events[1].args[1] === var + @test handler.events[1].args[2] === marginal_context + + # After: variable, context, messages, result + @test handler.events[2].event === :after_marginal_computation + @test handler.events[2].args[1] === var + @test handler.events[2].args[2] === marginal_context + @test length(handler.events[2].args[3]) == 3 + @test getdata(handler.events[2].args[4]) ≈ 6.0 + end + + @testset "Fires before and after marginal computation with 2 messages" begin + listen_to = (:before_marginal_computation, :after_marginal_computation) + handler = MarginalCallbackHandler(listen_to, []) + marginal_context = MessageProductContext( + fold_strategy = (variable, context, msgs) -> msg(sum(getdata.(msgs))), + callbacks = handler, + ) + + var = randomvar() + + messageins = map(1:2) do _ + s = Subject(AbstractMessage) + m, i = create_messagein!(var) + connect!(m, s) + return s + end + + activate!( + var, + RandomVariableActivationOptions( + AsapScheduler(), + MessageProductContext(), + marginal_context, + ), + ) + + messages = [msg(10.0), msg(20.0)] + + marginal_result = check_stream_updated_once(getmarginal(var)) do + foreach(zip(messageins, messages)) do (messagein, message) + next!(messagein, message) + end + end + + # sum(10.0 + 20.0) = 30.0 + @test getdata(marginal_result) ≈ 30.0 + + @test length(handler.events) == 2 + + @test handler.events[1].event === :before_marginal_computation + @test handler.events[1].args[1] === var + + @test handler.events[2].event === :after_marginal_computation + @test handler.events[2].args[1] === var + @test length(handler.events[2].args[3]) == 2 + @test getdata(handler.events[2].args[4]) ≈ 30.0 + end +end + @testitem "RandomVariable: activate! - zero or less than one inbound messages should throw" begin import ReactiveMP: RandomVariableActivationOptions, activate!, messageout