Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b5cc69c
feat: add event handler system for message passing procedure
bvdmitri Mar 5, 2026
ea3b6dc
type in docstring
bvdmitri Mar 5, 2026
eea47c8
inject the event handler in the factor node activation options
bvdmitri Mar 9, 2026
74b57a1
make format
bvdmitri Mar 9, 2026
3babdbe
document the existing events better
bvdmitri Mar 9, 2026
083f798
rename event handler to callbacks to better match RxInfer
bvdmitri Mar 9, 2026
4488d79
update docs
bvdmitri Mar 9, 2026
eedd31d
use make format
bvdmitri Mar 9, 2026
32ff1b6
fix typo
bvdmitri Mar 9, 2026
d6aafa6
allow to merge callbacks
bvdmitri Mar 9, 2026
6b6a1d4
allow reduce the result of the callbacks
bvdmitri Mar 9, 2026
b77cae2
add per-event callback reducer
bvdmitri Mar 10, 2026
4a9e28a
make format
bvdmitri Mar 10, 2026
99fc25e
remove test method
bvdmitri Mar 10, 2026
9bd5c18
start reimplement of the product of two messages
bvdmitri Mar 16, 2026
b264145
fix initial integration with RxInfer.jl
bvdmitri Mar 16, 2026
14c1925
fix Aqua tests
bvdmitri Mar 16, 2026
483c59a
caught a small bug in RxInfer tests
bvdmitri Mar 16, 2026
1ee727d
caught another bug from RxInfer.jl tests
bvdmitri Mar 16, 2026
d24879b
Merge branch 'main' into callbacks
bvdmitri Mar 19, 2026
629295f
Merge branch 'main' into callbacks
bvdmitri Mar 19, 2026
42e92c1
merge stricter formatting
bvdmitri Mar 19, 2026
41d6032
Refactor the variables, add docstrings, add labels
bvdmitri Mar 19, 2026
3f6d2d2
fix documentation build
bvdmitri Mar 19, 2026
c65d2c4
2prev
bvdmitri Mar 19, 2026
2d38c63
add before/after product of two messages callbacks
bvdmitri Mar 19, 2026
bf495be
fix failing tests
bvdmitri Mar 19, 2026
5a39369
add before/after product of messages callbacks. add before/after form…
bvdmitri Mar 19, 2026
609f466
add before/after marginal compute callbacks
bvdmitri Mar 19, 2026
aec5de0
add error hint in case of wrong passed callbacks
bvdmitri Mar 19, 2026
1c08d09
fix test
bvdmitri Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ makedocs(
"Introduction" => "index.md",
"Library" => [
"Factor nodes" => "lib/nodes.md",
"Variables" => "lib/variables.md",
"Messages" => "lib/message.md",
"Marginals" => "lib/marginal.md",
"Message update rules" => "lib/rules.md",
"Callbacks" => "lib/callbacks.md",
"Helper utils" => "lib/helpers.md",
"Algebra utils" => "lib/algebra.md",
"Specific factor nodes" => [
Expand Down
27 changes: 27 additions & 0 deletions docs/src/lib/callbacks.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Callbacks in the Message Passing Procedure

ReactiveMP provides a way to "hook" into the message passing procedure and listen to various events
via "callbacks". This can be useful, for example, to debug messages or monitor the order of computations.

```@docs
ReactiveMP.invoke_callback
ReactiveMP.merge_callbacks
ReactiveMP.MergedCallbacks
```

## All defined events

Here is the list of predefined event types, to which a custom callback handler can react to.

```@docs
ReactiveMP.BeforeMessageRuleCallback
ReactiveMP.AfterMessageRuleCallback
ReactiveMP.BeforeProductOfTwoMessages
ReactiveMP.AfterProductOfTwoMessages
ReactiveMP.BeforeProductOfMessages
ReactiveMP.AfterProductOfMessages
ReactiveMP.BeforeFormConstraintApplied
ReactiveMP.AfterFormConstraintApplied
ReactiveMP.BeforeMarginalComputation
ReactiveMP.AfterMarginalComputation
```
14 changes: 8 additions & 6 deletions docs/src/lib/message.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,15 @@ is_clamped(message), is_initial(message)
### [Product of messages](@id lib-messages-product)

In message passing framework, in order to compute a posterior we must compute a normalized product of two messages.
For this purpose the `ReactiveMP.jl` uses the `multiply_messages` function, which internally uses the `prod` function
defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the
`BayesBase.jl` package for more information.
For this purpose the `ReactiveMP.jl` uses the [`ReactiveMP.MessageProductContext`](@ref) structure, together with the [`ReactiveMP.compute_product_of_messages`](@ref) and [`ReactiveMP.compute_product_of_two_messages`](@ref) functions. Both functions accept a [`ReactiveMP.AbstractVariable`](@ref) as the first argument to identify which variable the product is being computed for — this is useful for callbacks (e.g. [`ReactiveMP.BeforeProductOfTwoMessages`](@ref)). The [`ReactiveMP.compute_product_of_two_messages`](@ref) function internally uses the `prod` function
defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the `BayesBase.jl` package for more information.

```@docs
ReactiveMP.multiply_messages
ReactiveMP.messages_prod_fn
ReactiveMP.MessageProductContext
ReactiveMP.compute_product_of_two_messages
ReactiveMP.compute_product_of_messages
ReactiveMP.MessagesProductFromLeftToRight
ReactiveMP.MessagesProductFromRightToLeft
```


Expand All @@ -95,4 +97,4 @@ A *message mapping* defines how messages are transformed or mapped during the pr

```@docs
ReactiveMP.MessageMapping
```
```
21 changes: 0 additions & 21 deletions docs/src/lib/nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,6 @@ ReactiveMP.RequireMarginalFunctionalDependencies
ReactiveMP.RequireEverythingFunctionalDependencies
```

## [Customizing Dependencies with Metadata](@id lib-node-metadata-dependencies)

The functional dependencies of a node can be customized at runtime using options during node activation. This allows for runtime customization of the functional dependencies, e.g. to test different message passing schemes or implement specialized behavior for specific instances of a node type:

```julia
# Define custom dependencies based on metadata
function ReactiveMP.collect_functional_dependencies(::Type{MyNode}, options::FactorNodeActivationOptions)
if some_condition(options) # a user can specify dependencies based, for example, on metadata
return CustomDependencies()
end
# Fall back to default dependencies
return ReactiveMP.collect_functional_dependencies(MyNode, getdependecies(options))
end

# Use custom dependencies during activation
node = factornode(MyNode, ...)
activate!(node, FactorNodeActivationOptions(:custom_behavior, ...))
```

This feature is particularly useful for testing different message passing schemes or implementing specialized behavior for specific instances of a node type.

## [Node traits](@id lib-node-traits)

Each factor node has to define the [`ReactiveMP.is_predefined_node`](@ref) trait function and to specify a [`ReactiveMP.PredefinedNodeFunctionalForm`](@ref)
Expand Down
36 changes: 36 additions & 0 deletions docs/src/lib/variables.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

# [Variables](@id lib-variables)

Variables are fundamental building blocks of a factor graph. Each variable represents either a latent quantity to be inferred, an observed data point, or a fixed constant. All variable types are subtypes of [`ReactiveMP.AbstractVariable`](@ref).

```@docs
ReactiveMP.AbstractVariable
```

## [Random variables](@id lib-variables-random)

Random variables represent latent (unobserved) quantities in the model. During inference, messages flow through them to update the marginal belief.

```@docs
ReactiveMP.RandomVariable
ReactiveMP.randomvar
```

## [Data variables](@id lib-variables-data)

Data variables represent observed quantities. Their value is not fixed at creation time and can be updated later via [`update!`](@ref).

```@docs
ReactiveMP.DataVariable
ReactiveMP.datavar
ReactiveMP.update!
```

## [Constant variables](@id lib-variables-constant)

Constant variables hold a fixed value, wrapped in a `PointMass` distribution. Messages from constant variables are always marked as clamped.

```@docs
ReactiveMP.ConstVariable
ReactiveMP.constvar
```
8 changes: 8 additions & 0 deletions ext/ReactiveMPProjectionExt/layout/cvi_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
return deltafn_apply_layout(
DeltaFnDefaultRuleLayout(),
Expand All @@ -62,6 +63,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
end

Expand All @@ -75,6 +77,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
return deltafn_apply_layout(
DeltaFnDefaultRuleLayout(),
Expand All @@ -85,6 +88,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
end

Expand All @@ -98,6 +102,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
let interface = factornode.out
msgs_names = Val{(:out,)}()
Expand Down Expand Up @@ -125,6 +130,7 @@ function deltafn_apply_layout(
addons,
factornode,
rulefallback,
callbacks,
)
(dependencies) -> DeferredMessage(
dependencies[1], dependencies[2], messagemap
Expand Down Expand Up @@ -152,6 +158,7 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
return deltafn_apply_layout(
DeltaFnDefaultRuleLayout(),
Expand All @@ -162,5 +169,6 @@ function deltafn_apply_layout(
scheduler,
addons,
rulefallback,
callbacks,
)
end
20 changes: 19 additions & 1 deletion src/ReactiveMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ include("helpers/algebra/standard_basis_vector.jl")

include("constraints/form.jl")

include("callbacks.jl")
include("variable.jl")
include("message.jl")
include("marginal.jl")
include("addons.jl")
Expand Down Expand Up @@ -76,7 +78,7 @@ include("approximations/cvi_projection.jl")
# Equality node is a special case and needs to be included before random variable implementation
include("nodes/equality.jl")

include("variables/variable.jl")
include("variables/generic.jl")
include("variables/random.jl")
include("variables/constant.jl")
include("variables/data.jl")
Expand Down Expand Up @@ -113,6 +115,22 @@ function __init__()
"""
println(io, errmsg)
end
if exc.f === ReactiveMP.invoke_callback && length(argtypes) >= 2
event_type = length(argtypes) >= 2 ? argtypes[2] : "unknown"
errmsg = """

`ReactiveMP.invoke_callback` was called with a callback handler of type `$(argtypes[1])` for event `$(event_type)`, but no matching method was found. This can happen if:

1. You implemented a custom callback handler but forgot to define `invoke_callback` for this specific event type.
Make sure your handler has a method like:
ReactiveMP.invoke_callback(::$(argtypes[1]), ::$(event_type), args...) = ...

2. You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma.
In Julia, `(key = value)` is parsed as a plain assignment, not a NamedTuple.
Use `(key = value,)` (with a trailing comma) instead.
"""
println(io, errmsg)
end
end
end

Expand Down
30 changes: 15 additions & 15 deletions src/approximations/unscented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ function statistic_estimation(
# Compute `C_tilde` only if `C === true`
@inbounds C_tilde = if C
reshape(
sum(
wc * (xi - m) * (yi - m_tilde) for
(wc, xi, yi) in zip(weights_c, sigma_points, g_sigma)
),
1,
d_out,
)
sum(
wc * (xi - m) * (yi - m_tilde) for
(wc, xi, yi) in zip(weights_c, sigma_points, g_sigma)
),
1,
d_out,
)
else
nothing
end
Expand Down Expand Up @@ -224,10 +224,10 @@ function unscented_statistics(
# Compute `C_tilde` only if `C === true`
@inbounds C_tilde = if C
sum(
weights_c[k + 1] *
(sigma_points[k + 1] - m) *
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
)
weights_c[k + 1] *
(sigma_points[k + 1] - m) *
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
)
else
nothing
end
Expand Down Expand Up @@ -258,10 +258,10 @@ function unscented_statistics(
# Compute `C_tilde` only if `C === true`
@inbounds C_tilde = if C
sum(
weights_c[k + 1] *
(sigma_points[k + 1] - m) *
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
)
weights_c[k + 1] *
(sigma_points[k + 1] - m) *
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
)
else
nothing
end
Expand Down
Loading
Loading