Skip to content

Commit a40573e

Browse files
authored
Implement event handler system for message passing and callbacks (#587)
* feat: add event handler system for message passing procedure Introduce a new event handler mechanism that allows users to hook into the message passing procedure via `Event`, `handle_event`, and `broadcast_event` functions. Events are broadcast before and after message rule calls, enabling debugging and monitoring capabilities. Key changes: - Add `event_handler.jl` with event types and handler interfaces - Extend `MessageMapping` to include an event handler parameter - Support both custom handler types and NamedTuple-based handlers - Add documentation for the events system - Include comprehensive tests for event handler functionality * type in docstring * inject the event handler in the factor node activation options * make format * document the existing events better * rename event handler to callbacks to better match RxInfer * update docs * use make format * fix typo * allow to merge callbacks * allow reduce the result of the callbacks * add per-event callback reducer * make format * remove test method * start reimplement of the product of two messages this is required to inject the callbacks properly, plus we do the breaking release already. Could break (and improve) more things as well then * fix initial integration with RxInfer.jl * fix Aqua tests * caught a small bug in RxInfer tests * caught another bug from RxInfer.jl tests * merge stricter formatting * Refactor the variables, add docstrings, add labels - `compute_product_of_messages` now accepts the `AbstractVariable`, makes it easier to identify the variable inside the callback * fix documentation build * 2prev * add before/after product of two messages callbacks * fix failing tests * add before/after product of messages callbacks. add before/after form constraints * add before/after marginal compute callbacks * add error hint in case of wrong passed callbacks * fix test * temporary fix for the logscale switch rule * support Dict as callback handler * fix warnings in the tests * make format * refactor to use an abstract Event{E} structure * add new method for event_name * use mutable fields instead * update CHANGELOG
1 parent ec85921 commit a40573e

31 files changed

Lines changed: 2115 additions & 237 deletions

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Callback/event system for hooking into message passing steps (rule calls, message products, form constraints, marginal computation)
12+
- `MessageProductContext` struct to bundle product computation settings and callbacks
13+
- Labels for variables (`RandomVariable`, `ConstVariable`, `DataVariable`)
14+
- Docstrings for variable types, form constraints, and related functions
15+
- Documentation page for callbacks
16+
- `MethodError` hint for mismatched `handle_event` signatures
17+
1018
### Changed
1119
- Switched from `ReTestItems` to `TestItemRunner` for tests ([#584](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/584))
1220
- Made formatting checks stricter
21+
- Renamed `variables/variable.jl` to `variables/generic.jl`
22+
- Replaced hardcoded `DefaultMessageProdFn`/`DefaultMarginalProdFn` with `MessageProductContext`
1323

1424
## [5.6.6] - 2026-03-13
1525

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ makedocs(
1414
"Introduction" => "index.md",
1515
"Library" => [
1616
"Factor nodes" => "lib/nodes.md",
17+
"Variables" => "lib/variables.md",
1718
"Messages" => "lib/message.md",
1819
"Marginals" => "lib/marginal.md",
1920
"Message update rules" => "lib/rules.md",
21+
"Callbacks" => "lib/callbacks.md",
2022
"Helper utils" => "lib/helpers.md",
2123
"Algebra utils" => "lib/algebra.md",
2224
"Specific factor nodes" => [

docs/src/lib/callbacks.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Callbacks in the Message Passing Procedure
2+
3+
ReactiveMP provides a way to "hook" into the message passing procedure and listen to various events
4+
via "callbacks". This can be useful, for example, to debug messages or monitor the order of computations.
5+
6+
```@docs
7+
ReactiveMP.Event
8+
ReactiveMP.event_name
9+
ReactiveMP.handle_event
10+
ReactiveMP.invoke_callback
11+
ReactiveMP.merge_callbacks
12+
ReactiveMP.MergedCallbacks
13+
```
14+
15+
## Event naming convention
16+
17+
Every event in ReactiveMP is a concrete subtype of [`ReactiveMP.Event{E}`](@ref) where `E` is a `Symbol` identifying the event.
18+
The naming convention is straightforward: for an event identified by the symbol `:event_name`, the corresponding struct is called `EventNameEvent`.
19+
For example:
20+
21+
| Symbol | Struct |
22+
|--------|--------|
23+
| `:before_message_rule_call` | [`ReactiveMP.BeforeMessageRuleCallEvent`](@ref) |
24+
| `:after_product_of_two_messages` | [`ReactiveMP.AfterProductOfTwoMessagesEvent`](@ref) |
25+
| `:before_form_constraint_applied` | [`ReactiveMP.BeforeFormConstraintAppliedEvent`](@ref) |
26+
27+
Each event struct carries the relevant data as fields, so you can inspect what happened during inference.
28+
You can use [`ReactiveMP.event_name`](@ref) to retrieve the symbol from any event type:
29+
30+
```@example callbacks
31+
using ReactiveMP #hide
32+
ReactiveMP.event_name(ReactiveMP.BeforeProductOfTwoMessagesEvent)
33+
```
34+
35+
To see which fields an event carries, use the standard Julia introspection:
36+
37+
```julia
38+
julia> ?ReactiveMP.BeforeProductOfTwoMessagesEvent
39+
```
40+
41+
## All defined events
42+
43+
Here is the list of predefined event types, to which a custom callback handler can react to.
44+
45+
```@docs
46+
ReactiveMP.BeforeMessageRuleCallEvent
47+
ReactiveMP.AfterMessageRuleCallEvent
48+
ReactiveMP.BeforeProductOfTwoMessagesEvent
49+
ReactiveMP.AfterProductOfTwoMessagesEvent
50+
ReactiveMP.BeforeProductOfMessagesEvent
51+
ReactiveMP.AfterProductOfMessagesEvent
52+
ReactiveMP.BeforeFormConstraintAppliedEvent
53+
ReactiveMP.AfterFormConstraintAppliedEvent
54+
ReactiveMP.BeforeMarginalComputationEvent
55+
ReactiveMP.AfterMarginalComputationEvent
56+
```

docs/src/lib/message.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,15 @@ is_clamped(message), is_initial(message)
7373
### [Product of messages](@id lib-messages-product)
7474

7575
In message passing framework, in order to compute a posterior we must compute a normalized product of two messages.
76-
For this purpose the `ReactiveMP.jl` uses the `multiply_messages` function, which internally uses the `prod` function
77-
defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the
78-
`BayesBase.jl` package for more information.
76+
For this purpose the `ReactiveMP.jl` uses the [`ReactiveMP.MessageProductContext`](@ref) structure, together with the [`ReactiveMP.compute_product_of_messages`](@ref) and [`ReactiveMP.compute_product_of_two_messages`](@ref) functions. Both functions accept a [`ReactiveMP.AbstractVariable`](@ref) as the first argument to identify which variable the product is being computed for — this is useful for callbacks (e.g. [`ReactiveMP.BeforeProductOfTwoMessagesEvent`](@ref)). The [`ReactiveMP.compute_product_of_two_messages`](@ref) function internally uses the `prod` function
77+
defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the `BayesBase.jl` package for more information.
7978

8079
```@docs
81-
ReactiveMP.multiply_messages
82-
ReactiveMP.messages_prod_fn
80+
ReactiveMP.MessageProductContext
81+
ReactiveMP.compute_product_of_two_messages
82+
ReactiveMP.compute_product_of_messages
83+
ReactiveMP.MessagesProductFromLeftToRight
84+
ReactiveMP.MessagesProductFromRightToLeft
8385
```
8486

8587

@@ -95,4 +97,4 @@ A *message mapping* defines how messages are transformed or mapped during the pr
9597

9698
```@docs
9799
ReactiveMP.MessageMapping
98-
```
100+
```

docs/src/lib/nodes.md

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,6 @@ ReactiveMP.RequireMarginalFunctionalDependencies
113113
ReactiveMP.RequireEverythingFunctionalDependencies
114114
```
115115

116-
## [Customizing Dependencies with Metadata](@id lib-node-metadata-dependencies)
117-
118-
The functional dependencies of a node can be customized at runtime using options during node activation. This allows for runtime customization of the functional dependencies, e.g. to test different message passing schemes or implement specialized behavior for specific instances of a node type:
119-
120-
```julia
121-
# Define custom dependencies based on metadata
122-
function ReactiveMP.collect_functional_dependencies(::Type{MyNode}, options::FactorNodeActivationOptions)
123-
if some_condition(options) # a user can specify dependencies based, for example, on metadata
124-
return CustomDependencies()
125-
end
126-
# Fall back to default dependencies
127-
return ReactiveMP.collect_functional_dependencies(MyNode, getdependecies(options))
128-
end
129-
130-
# Use custom dependencies during activation
131-
node = factornode(MyNode, ...)
132-
activate!(node, FactorNodeActivationOptions(:custom_behavior, ...))
133-
```
134-
135-
This feature is particularly useful for testing different message passing schemes or implementing specialized behavior for specific instances of a node type.
136-
137116
## [Node traits](@id lib-node-traits)
138117

139118
Each factor node has to define the [`ReactiveMP.is_predefined_node`](@ref) trait function and to specify a [`ReactiveMP.PredefinedNodeFunctionalForm`](@ref)

docs/src/lib/variables.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
# [Variables](@id lib-variables)
3+
4+
Variables are fundamental building blocks of a factor graph. Each variable represents either a latent quantity to be inferred, an observed data point, or a fixed constant. All variable types are subtypes of [`ReactiveMP.AbstractVariable`](@ref).
5+
6+
```@docs
7+
ReactiveMP.AbstractVariable
8+
```
9+
10+
## [Random variables](@id lib-variables-random)
11+
12+
Random variables represent latent (unobserved) quantities in the model. During inference, messages flow through them to update the marginal belief.
13+
14+
```@docs
15+
ReactiveMP.RandomVariable
16+
ReactiveMP.randomvar
17+
```
18+
19+
## [Data variables](@id lib-variables-data)
20+
21+
Data variables represent observed quantities. Their value is not fixed at creation time and can be updated later via [`update!`](@ref).
22+
23+
```@docs
24+
ReactiveMP.DataVariable
25+
ReactiveMP.datavar
26+
ReactiveMP.update!
27+
```
28+
29+
## [Constant variables](@id lib-variables-constant)
30+
31+
Constant variables hold a fixed value, wrapped in a `PointMass` distribution. Messages from constant variables are always marked as clamped.
32+
33+
```@docs
34+
ReactiveMP.ConstVariable
35+
ReactiveMP.constvar
36+
```

ext/ReactiveMPProjectionExt/layout/cvi_projection.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ function deltafn_apply_layout(
5252
scheduler,
5353
addons,
5454
rulefallback,
55+
callbacks,
5556
)
5657
return deltafn_apply_layout(
5758
DeltaFnDefaultRuleLayout(),
@@ -62,6 +63,7 @@ function deltafn_apply_layout(
6263
scheduler,
6364
addons,
6465
rulefallback,
66+
callbacks,
6567
)
6668
end
6769

@@ -75,6 +77,7 @@ function deltafn_apply_layout(
7577
scheduler,
7678
addons,
7779
rulefallback,
80+
callbacks,
7881
)
7982
return deltafn_apply_layout(
8083
DeltaFnDefaultRuleLayout(),
@@ -85,6 +88,7 @@ function deltafn_apply_layout(
8588
scheduler,
8689
addons,
8790
rulefallback,
91+
callbacks,
8892
)
8993
end
9094

@@ -98,6 +102,7 @@ function deltafn_apply_layout(
98102
scheduler,
99103
addons,
100104
rulefallback,
105+
callbacks,
101106
)
102107
let interface = factornode.out
103108
msgs_names = Val{(:out,)}()
@@ -125,6 +130,7 @@ function deltafn_apply_layout(
125130
addons,
126131
factornode,
127132
rulefallback,
133+
callbacks,
128134
)
129135
(dependencies) -> DeferredMessage(
130136
dependencies[1], dependencies[2], messagemap
@@ -152,6 +158,7 @@ function deltafn_apply_layout(
152158
scheduler,
153159
addons,
154160
rulefallback,
161+
callbacks,
155162
)
156163
return deltafn_apply_layout(
157164
DeltaFnDefaultRuleLayout(),
@@ -162,5 +169,6 @@ function deltafn_apply_layout(
162169
scheduler,
163170
addons,
164171
rulefallback,
172+
callbacks,
165173
)
166174
end

src/ReactiveMP.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ include("helpers/algebra/standard_basis_vector.jl")
2121

2222
include("constraints/form.jl")
2323

24+
include("callbacks.jl")
25+
include("variable.jl")
2426
include("message.jl")
2527
include("marginal.jl")
2628
include("addons.jl")
@@ -76,7 +78,7 @@ include("approximations/cvi_projection.jl")
7678
# Equality node is a special case and needs to be included before random variable implementation
7779
include("nodes/equality.jl")
7880

79-
include("variables/variable.jl")
81+
include("variables/generic.jl")
8082
include("variables/random.jl")
8183
include("variables/constant.jl")
8284
include("variables/data.jl")
@@ -113,6 +115,27 @@ function __init__()
113115
"""
114116
println(io, errmsg)
115117
end
118+
if exc.f === ReactiveMP.handle_event && length(argtypes) >= 2
119+
event_type = argtypes[2]
120+
event_hint = if event_type <: ReactiveMP.Event
121+
"Event{$(repr(ReactiveMP.event_name(event_type)))}"
122+
else
123+
string(event_type)
124+
end
125+
errmsg = """
126+
127+
`ReactiveMP.handle_event` was called with a callback handler of type `$(argtypes[1])` for event `$(event_type)`, but no matching method was found. This can happen if:
128+
129+
1. You implemented a custom callback handler but forgot to define `handle_event` for this specific event type.
130+
Make sure your handler has a method like:
131+
ReactiveMP.handle_event(::$(argtypes[1]), event::$(event_hint)) = ...
132+
133+
2. You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma.
134+
In Julia, `(key = value)` is parsed as a plain assignment, not a NamedTuple.
135+
Use `(key = value,)` (with a trailing comma) instead.
136+
"""
137+
println(io, errmsg)
138+
end
116139
end
117140
end
118141

src/approximations/unscented.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ function statistic_estimation(
135135
# Compute `C_tilde` only if `C === true`
136136
@inbounds C_tilde = if C
137137
reshape(
138-
sum(
139-
wc * (xi - m) * (yi - m_tilde) for
140-
(wc, xi, yi) in zip(weights_c, sigma_points, g_sigma)
141-
),
142-
1,
143-
d_out,
144-
)
138+
sum(
139+
wc * (xi - m) * (yi - m_tilde) for
140+
(wc, xi, yi) in zip(weights_c, sigma_points, g_sigma)
141+
),
142+
1,
143+
d_out,
144+
)
145145
else
146146
nothing
147147
end
@@ -224,10 +224,10 @@ function unscented_statistics(
224224
# Compute `C_tilde` only if `C === true`
225225
@inbounds C_tilde = if C
226226
sum(
227-
weights_c[k + 1] *
228-
(sigma_points[k + 1] - m) *
229-
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
230-
)
227+
weights_c[k + 1] *
228+
(sigma_points[k + 1] - m) *
229+
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
230+
)
231231
else
232232
nothing
233233
end
@@ -258,10 +258,10 @@ function unscented_statistics(
258258
# Compute `C_tilde` only if `C === true`
259259
@inbounds C_tilde = if C
260260
sum(
261-
weights_c[k + 1] *
262-
(sigma_points[k + 1] - m) *
263-
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
264-
)
261+
weights_c[k + 1] *
262+
(sigma_points[k + 1] - m) *
263+
(g_sigma[k + 1] - m_tilde)' for k in 0:(2d)
264+
)
265265
else
266266
nothing
267267
end

0 commit comments

Comments
 (0)