diff --git a/.github/workflows/Changelog.yml b/.github/workflows/Changelog.yml new file mode 100644 index 000000000..013b52cb2 --- /dev/null +++ b/.github/workflows/Changelog.yml @@ -0,0 +1,14 @@ +name: Check CHANGELOG.md +on: + pull_request: + types: [assigned, opened, synchronize, reopened, labeled, unlabeled] + branches: + - main +jobs: + check-changelog: + name: Check Changelog Action + runs-on: ubuntu-latest + steps: + - uses: tarides/changelog-check-action@v2 + with: + changelog: CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..61c2c0766 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,228 @@ +# Changelog + +All notable changes to ReactiveMP.jl will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [6.0.0] + +### Added +- `AbstractStreamPostprocessor` abstraction unifying the old pipeline stages and the per-node `scheduler` argument under a single concept that postprocesses outbound message streams, marginal streams, and score streams uniformly +- `postprocess_stream_of_outbound_messages`, `postprocess_stream_of_marginals`, `postprocess_stream_of_scores` entry points with `::Nothing` pass-through fallbacks +- `CompositeStreamPostprocessor` for chaining multiple postprocessors +- `ScheduleOnStreamPostprocessor` — direct successor of `ScheduleOnPipelineStage` plus the per-node scheduler, applies a Rocket.jl scheduler to all three stream kinds +- Marginal streams and score streams now go through stream postprocessors (previously only outbound message streams did) +- Documentation page for stream postprocessors +- Callback/event system for hooking into message passing steps (rule calls, message products, form constraints, marginal computation) +- `MessageProductContext` struct to bundle product computation settings and callbacks +- Labels for variables (`RandomVariable`, `ConstVariable`, `DataVariable`) +- Docstrings for variable types, form constraints, and related functions +- Documentation page for callbacks +- `MethodError` hint for mismatched `handle_event` signatures +- New annotations system: `AnnotationDict`, `AbstractAnnotations`, `LogScaleAnnotations`, `InputArgumentsAnnotations` +- `post_rule_annotations!` and `post_product_annotations!` callbacks for annotation processors +- `@logscale value` macro for setting log-scale annotations inside `@rule` bodies +- `getannotations` function for `Message` and `Marginal` +- Migration guide for v5 to v6 +- `skip_initial()`, `skip_clamped()`, `skip_clamped_and_initial()` filter operators replacing the `MarginalSkipStrategy` type hierarchy +- `new_observation!(datavar, value)` for pushing observed values into a `DataVariable` +- `get_stream_of_inbound_messages`, `get_stream_of_outbound_messages` accessors on `NodeInterface` and `IndexedNodeInterface` +- `get_stream_of_marginals`, `set_stream_of_marginals!` accessors on variables +- `get_stream_of_predictions`, `set_stream_of_predictions!` accessors on variables +- `set_initial_marginal!`, `set_initial_message!` for seeding variables before inference +- `create_new_stream_of_inbound_messages!` for allocating per-connection message streams +- Docstrings for `MessageObservable`, `MarginalObservable`, `FunctionalDependencies`, `collect_functional_dependencies`, `RandomVariableActivationOptions`, `DataVariableActivationOptions`, `FactorNodeActivationOptions`, and `activate!` methods +- Expanded documentation for variables (stream creation lifecycle per variable type), nodes (interfaces, activation), messages, and marginals + +### Changed +- `FactorNodeActivationOptions` lost its `pipeline` and `scheduler` positional fields and gained a single `postprocessor` field +- `RandomVariableActivationOptions` renamed its `scheduler` field to `stream_postprocessor`; the default is now `nothing` (no-op) instead of `AsapScheduler()` +- `getpipeline(options)` and `getscheduler(options)` replaced by `getpostprocessor(options)` +- `EqualityChain` renamed its `pipeline` field to `postprocessor` +- Switched from `ReTestItems` to `TestItemRunner` for tests ([#584](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/584)) +- Made formatting checks stricter +- Removed `variables/generic.jl`; generic variable interface moved into `variable.jl` +- Replaced hardcoded `DefaultMessageProdFn`/`DefaultMarginalProdFn` with `MessageProductContext` +- `Message{D, A}` → `Message{D}` (type parameter `A` removed) +- `Marginal{D, A}` → `Marginal{D}` (type parameter `A` removed) +- `Message` and `Marginal` now carry an `AnnotationDict` instead of a typed addons tuple +- Rules no longer return `(result, addons)` tuples — just the result +- `@call_rule` no longer supports `return_addons` option; use `annotations` keyword with `AnnotationDict` +- `MessageMapping.addons` field → `MessageMapping.annotations` +- `MessageProductContext` gained `annotations` field for product-time annotation processors +- `messagein(interface)` → `get_stream_of_inbound_messages(interface)` +- `messageout(interface)` → `get_stream_of_outbound_messages(interface)` +- `getmarginal(variable)` / `getmarginals` → `get_stream_of_marginals(variable)` +- `getprediction(variable)` / `getpredictions` → `get_stream_of_predictions(variable)` +- `setmarginal!(variable, value)` → `set_initial_marginal!(variable, value)` +- `setmessage!(variable, value)` → `set_initial_message!(variable, value)` +- `update!(datavar, value)` → `new_observation!(datavar, value)` + +### Removed +- `AbstractPipelineStage`, `EmptyPipelineStage`, `CompositePipelineStage`, `ScheduleOnPipelineStage`, `apply_pipeline_stage`, `collect_pipeline`, `+` composition — replaced by the `AbstractStreamPostprocessor` abstraction (see migration guide) +- `LoggerPipelineStage` — equivalent behaviour can be implemented via callbacks +- `AsyncPipelineStage` — use `ScheduleOnStreamPostprocessor(AsyncScheduler())` instead +- `DiscontinuePipelineStage` — was unused; implement a custom `AbstractStreamPostprocessor` if needed +- `schedule_updates(vars; pipeline_stage = ...)` — construct a `ScheduleOnStreamPostprocessor` and pass it through the activation options instead +- `getaddons` — use `getannotations` instead +- `getlogscale(::Message)`, `getlogscale(::Marginal)` — use `getlogscale(getannotations(...))` instead +- `getmemory`, `getmemoryaddon` — use `get_rule_input_arguments(getannotations(...))` instead +- `AddonLogScale` — replaced by `LogScaleAnnotations` (calling `AddonLogScale()` throws a descriptive error) +- `AddonMemory` — replaced by `InputArgumentsAnnotations` (calling `AddonMemory()` throws a descriptive error) +- `AddonDebug` — use callbacks instead +- `AbstractAddon`, `multiply_addons`, `@invokeaddon` +- `message_mapping_addons`, `message_mapping_addon` helper functions +- `MarginalSkipStrategy` abstract type and `SkipClamped`, `SkipInitial`, `SkipClampedAndInitial`, `IncludeAll` subtypes — use `skip_clamped()`, `skip_initial()`, `skip_clamped_and_initial()` filter operators instead +- `apply_skip_filter`, `as_marginal_observable` — no longer part of the public API +- `messagein`, `messageout` — use `get_stream_of_inbound_messages`, `get_stream_of_outbound_messages` +- `getmarginal`, `getmarginals`, `getprediction`, `getpredictions` — use `get_stream_of_marginals`, `get_stream_of_predictions` +- `setmarginal!`, `setmarginals!`, `setmessage!`, `setmessages!` — use `set_initial_marginal!`, `set_initial_message!` +- `update!` — use `new_observation!` +- `create_messagein!` — use `create_new_stream_of_inbound_messages!` + +## [5.6.6] - 2026-03-13 + +### Fixed +- Implemented effective rules with specialized dispatch for `MvNormalMeanScalePrecision` ([#579](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/579)) + +### Tests +- Added performance test for structured rule specialized for `MvNormalMeanScalePrecision` + +## [5.6.5] - 2026-02-02 + +### Added +- Implemented `MvNormalWishart` node and `out` rule ([#565](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/565)) +- Issue templates ([#558](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/558)) +- Rule interface name checking for `@rule`, `@marginalrule`, and `@average_energy` macros ([#545](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/545)) + +### Changed +- Removed `Requires` dependency (used for Julia <1.9, no longer supported) ([#564](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/564)) +- Removed vibe coded required fields from issue template ([#562](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/562)) +- Bumped compat for Optim to 2 ([#574](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/574)) +- Removed `Zygote` extension requirement for compatibility + +### Fixed +- Fixed documentation build ([#567](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/567)) + +### Performance +- Pre-computed double loops in CT model ([#571](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/571)) + +## [5.6.4] - 2025-11-18 + +### Fixed +- Fixed bug in average energy of `Uninformative` type ([#553](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/553)) +- Improved robustness of inverse precision matrix computation in `MvNormalMeanPrecision` rule ([#540](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/540)) + +### Documentation +- Added `MessageMapping` documentation ([#550](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/550)) + +## [5.6.3] - 2025-11-04 + +### Added +- Added `IntegrationTest.yml` workflow ([#525](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/525)) +- Implemented missing marginal rule for multiplication node ([#531](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/531)) + +### Changed +- Changed dispatch to `AbstractVector` to allow other vector implementations ([#536](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/536)) + +### Fixed +- Renamed softdot marginal rules test file to include it in test runs ([#535](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/535)) +- Avoided `Vararg` deprecation warnings ([#537](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/537)) +- Defined `AverageEnergy` for `Mixture` node with warning ([#546](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/546)) +- Allowed different numeric types for `GammaShapeLikelihood` constructor ([#544](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/544)) + +### Tests +- Improved code coverage with tests for BIFM, Autoregressive, Mixture, GammaMixture, Wishart, InverseWishart, DotProduct, Multiplication, and Uniform nodes ([#539](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/539)) + +## [5.6.2] - 2025-10-21 + +### Fixed +- Fixed `isonehot` to use approximate comparison for categorical rules ([#527](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/527)) + +## [5.6.1] - 2025-10-21 + +### Added +- Categorical rules: check if probability vector of `q_out` is a one-hot encoded vector ([#510](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/510)) +- Support for non-linear node (univariate -> multivariate) with Unscented transform ([#508](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/508)) + +### Changed +- Updated `ForwardDiff` to version 1 ([#521](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/521)) +- Skip Aqua.jl checks during selective test runs ([#523](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/523)) +- Makefile: fixed selective test runs with `test_args` argument ([#517](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/517)) + +### Fixed +- Updated documentation for `as_marginal` ([#516](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/516)) +- Updated test cases for `GammaShapeScale` + +## [5.6.0] - 2025-09-23 + +### Changed +- Use `MvNormalMeanScaleMatrixPrecision` from ExponentialFamily package ([#509](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/509)) + +## [5.5.12] - 2025-09-11 + +### Fixed +- Fixed `q_t1` dimensionality bug in delta node rules ([#504](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/504)) + +## [5.5.11] - 2025-09-10 + +### Added +- Implemented `MvNormalMeanScaleMatrixPrecision` rules ([#497](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/497)) + +## [5.5.10] - 2025-09-09 + +### Fixed +- Added new linearization method ([#500](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/500)) + +## [5.5.9] - 2025-08-14 + +### Changed +- Reverted "Don't check for proper in division of" ([#496](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/496)) + +## [5.5.8] - 2025-08-14 + +### Added +- Show meta suggestions in rule error printing ([#495](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/495)) + +### Changed +- Tightened dispatch for summation and other optimized rules ([#492](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/492), [#493](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/493)) + +### Fixed +- Don't check for proper in division of to accommodate inference ([#486](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/486)) + +## [5.5.7] - 2025-07-24 + +### Fixed +- Fixed infinite RxInfer documentation build by not using lazy string ([#490](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/490)) + +## [5.5.6] - 2025-07-23 + +### Fixed +- Fixed invalidations: removed bad `eltype` methods, `convert`, and `println` method ([#489](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/489)) + +## [5.5.5] - 2025-07-23 + +### Fixed +- Resolved Gaussian division with proper Multivariate vs Univariate handling ([#479](https://github.com/ReactiveBayes/ReactiveMP.jl/pull/479)) + +--- + +[Unreleased]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.6...HEAD +[5.6.6]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.5...v5.6.6 +[5.6.5]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.4...v5.6.5 +[5.6.4]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.3...v5.6.4 +[5.6.3]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.2...v5.6.3 +[5.6.2]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.1...v5.6.2 +[5.6.1]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.6.0...v5.6.1 +[5.6.0]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.12...v5.6.0 +[5.5.12]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.11...v5.5.12 +[5.5.11]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.10...v5.5.11 +[5.5.10]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.9...v5.5.10 +[5.5.9]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.8...v5.5.9 +[5.5.8]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.7...v5.5.8 +[5.5.7]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.6...v5.5.7 +[5.5.6]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.5...v5.5.6 +[5.5.5]: https://github.com/ReactiveBayes/ReactiveMP.jl/compare/v5.5.4...v5.5.5 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..47e3a467c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,38 @@ +# Contributing to ReactiveMP + +**ReactiveMP.jl** is a community-driven reactive message passing engine and we welcome contributions of all kinds! + +ReactiveMP is part of the [ReactiveBayes](https://github.com/ReactiveBayes) organization and the broader [RxInfer](https://github.com/ReactiveBayes/RxInfer.jl) ecosystem for reactive Bayesian inference in Julia. + +## Getting Started + +* Browse **beginner-friendly issues**: https://github.com/ReactiveBayes/ReactiveMP.jl/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22 +* Check out the [RxInfer Contributing Guide](https://docs.rxinfer.com/stable/contributing/guide) for general conventions shared across the ecosystem + +We welcome contributions such as: + +* Bug reports and fixes +* New message passing rules and nodes +* Documentation improvements +* Tests and performance improvements +* Feature suggestions + +## Development Workflow + +1. Fork the repository and create a feature branch +2. Make your changes and ensure tests pass locally with `make test` +3. Run `make format` to ensure consistent code formatting +4. Open a pull request against `main` — the CI will check tests, formatting, and that `CHANGELOG.md` has been updated + +## Contributing to the RxInfer Ecosystem + +ReactiveMP is one of several packages in the RxInfer ecosystem. Contributions to any of these projects are very welcome: + +- [RxInfer.jl](https://github.com/ReactiveBayes/RxInfer.jl) — the high-level inference package +- [GraphPPL.jl](https://github.com/ReactiveBayes/GraphPPL.jl) — probabilistic model specification +- [ExponentialFamily.jl](https://github.com/ReactiveBayes/ExponentialFamily.jl) — exponential family distributions +- Browse other packages at https://github.com/ReactiveBayes + +If you're unsure where to start or where a contribution belongs, feel free to open an issue or start a discussion. + +Thank you for helping improve ReactiveMP! diff --git a/Project.toml b/Project.toml index 118126a6f..02ca77db3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ReactiveMP" uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3" -version = "5.6.6" authors = ["Dmitry Bagaev ", "Albert Podusenko ", "Bart van Erp ", "Ismail Senoz "] +version = "6.0.0" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" @@ -30,6 +30,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b" @@ -68,6 +69,7 @@ StatsFuns = "1.3.0" TinyHugeNumbers = "1.0.0" Tullio = "0.3" TupleTools = "1.2.0" +UUIDs = "1" julia = "1.10" [extras] @@ -87,9 +89,9 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" -TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" [targets] test = ["Aqua", "TestItemRunner", "Test", "Pkg", "Logging", "InteractiveUtils", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkTools", "JET", "PkgBenchmark", "StableRNGs", "Optimisers", "DiffResults", "ExponentialFamilyProjection", "REPL", "Manopt"] diff --git a/README.md b/README.md index a0daecbd3..422655c98 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,51 @@ -# ReactiveMP.jl + -| **Documentation** | **Build Status** | **Coverage** | **Zenodo DOI** | -|:-------------------------------------------------------------------------:|:--------------------------------:|:----------------------------------:|:--------------------------------:| -| [![][docs-stable-img]][docs-stable-url] [![][docs-dev-img]][docs-dev-url] | [![CI][ci-img]][ci-url] | [![Codecov][codecov-img]][codecov-url] | [![DOI][zenodo-img]][zenodo-url] | +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://reactivebayes.github.io/ReactiveMP.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://reactivebayes.github.io/ReactiveMP.jl/dev) +[![Build Status](https://github.com/reactivebayes/ReactiveMP.jl/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/reactivebayes/ReactiveMP.jl/actions) +[![Coverage](https://codecov.io/gh/reactivebayes/ReactiveMP.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/reactivebayes/ReactiveMP.jl) +[![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.8381133.svg)](https://zenodo.org/doi/10.5281/zenodo.5913616) -[docs-dev-img]: https://img.shields.io/badge/docs-dev-blue.svg -[docs-dev-url]: https://reactivebayes.github.io/ReactiveMP.jl/dev +# Overview -[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg -[docs-stable-url]: https://reactivebayes.github.io/ReactiveMP.jl/stable +`ReactiveMP.jl` is a Julia package that provides an efficient reactive message passing based Bayesian inference engine on a factor graph. The package is a part of the bigger and user-friendly ecosystem for automatic Bayesian inference called [RxInfer](https://github.com/reactivebayes/RxInfer.jl). While ReactiveMP.jl exports only the inference engine, RxInfer provides convenient tools for model and inference constraints specification as well as routines for running efficient inference both for static and dynamic datasets. -[ci-img]: https://github.com/reactivebayes/ReactiveMP.jl/actions/workflows/ci.yml/badge.svg?branch=main -[ci-url]: https://github.com/reactivebayes/ReactiveMP.jl/actions +ReactiveMP.jl is designed for advanced users who need fine-grained control over message passing, custom factor nodes, and custom update rules. It does not create a specific message passing schedule in advance, but rather _reacts_ on changes in the data source (hence _reactive_ in the name of the package). -[codecov-img]: https://codecov.io/gh/reactivebayes/ReactiveMP.jl/branch/main/graph/badge.svg -[codecov-url]: https://codecov.io/gh/reactivebayes/ReactiveMP.jl?branch=main +# Installation -[zenodo-img]: https://zenodo.org/badge/DOI/10.5281/zenodo.8381133.svg -[zenodo-url]: https://zenodo.org/doi/10.5281/zenodo.5913616 +Install ReactiveMP through the Julia package manager: -# Reactive message passing engine +```julia +] add ReactiveMP +``` -ReactiveMP.jl is a Julia package that provides an efficient reactive message passing based Bayesian inference engine on a factor graph. The package is a part of the bigger and user-friendly ecosystem for automatic Bayesian inference called [RxInfer](https://github.com/reactivebayes/RxInfer.jl). While ReactiveMP.jl exports only the inference engine, RxInfer provides convenient tools for model and inference constraints specification as well as routines for running efficient inference both for static and dynamic datasets. +Optionally, use `] test ReactiveMP` to validate the installation by running the test suite. -## Examples and tutorials +# Documentation -The ReactiveMP.jl package is intended for advanced users with a deep understanding of message passing principles. -Accesible tutorials and examples are available in the [RxInfer documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). +For more information about `ReactiveMP.jl` please refer to the [documentation](https://reactivebayes.github.io/ReactiveMP.jl/stable). + +# Examples and tutorials + +The ReactiveMP.jl package is intended for advanced users with a deep understanding of message passing principles. Accessible tutorials and examples are available in the [RxInfer documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). + +# Ecosystem + +The `RxInfer` framework consists of four *core* packages developed by [ReactiveBayes](https://github.com/reactivebayes/): + +- [`ReactiveMP.jl`](https://github.com/reactivebayes/ReactiveMP.jl) - the underlying message passing-based inference engine (this package) +- [`RxInfer.jl`](https://github.com/reactivebayes/RxInfer.jl) - user-friendly modeling and inference layer +- [`GraphPPL.jl`](https://github.com/reactivebayes/GraphPPL.jl) - model and constraints specification package +- [`ExponentialFamily.jl`](https://github.com/reactivebayes/ExponentialFamily.jl) - package for exponential family distributions +- [`Rocket.jl`](https://github.com/reactivebayes/Rocket.jl) - reactive extensions package for Julia + +# References + +- [A Julia package for reactive variational Bayesian inference](https://doi.org/10.1016/j.simpa.2022.100299) - a reference paper for the `ReactiveMP.jl` package. +- [Reactive Probabilistic Programming for Scalable Bayesian Inference](https://pure.tue.nl/ws/portalfiles/portal/313860204/20231219_Bagaev_hf.pdf) - a PhD dissertation outlining core ideas and principles behind ReactiveMP ([link2](https://research.tue.nl/nl/publications/reactive-probabilistic-programming-for-scalable-bayesian-inferenc), [link3](https://github.com/bvdmitri/phdthesis)). +- [Variational Message Passing and Local Constraint Manipulation in Factor Graphs](https://doi.org/10.3390/e23070807) - describes theoretical aspects of the underlying Bayesian inference method. +- [Reactive Message Passing for Scalable Bayesian Inference](https://doi.org/10.48550/arXiv.2112.13251) - describes implementation aspects of the Bayesian inference engine and performs benchmarks and accuracy comparison on various models. # License diff --git a/docs/make.jl b/docs/make.jl index ff4e5b60b..adabc611d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,11 +12,22 @@ makedocs( sitename = "ReactiveMP.jl", pages = [ "Introduction" => "index.md", + "Concepts" => [ + "Factor graphs" => "concepts/factor-graphs.md", + "Message passing" => "concepts/message-passing.md", + "Reactive programming" => "concepts/reactive-programming.md", + "Inference lifecycle" => "concepts/inference-lifecycle.md", + ], "Library" => [ "Factor nodes" => "lib/nodes.md", + "Variables" => "lib/variables.md", "Messages" => "lib/message.md", "Marginals" => "lib/marginal.md", "Message update rules" => "lib/rules.md", + "Callbacks" => "lib/callbacks.md", + "Stream postprocessors" => "lib/stream-postprocessors.md", + "Approximations" => "lib/approximations.md", + "Score functions" => "lib/score.md", "Helper utils" => "lib/helpers.md", "Algebra utils" => "lib/algebra.md", "Specific factor nodes" => [ @@ -31,9 +42,16 @@ makedocs( "MultinomialPolya" => "lib/nodes/multinomial_polya.md", ] ], + "Annotations" => [ + "Overview" => "lib/annotations.md", + "Log-scale" => "lib/annotations/logscale.md", + "Input arguments" => "lib/annotations/input_arguments.md", + ], "Custom functionality" => [ "Custom functional form" => "custom/custom-functional-form.md", - "Custom addons" => "custom/custom-addons.md" + ], + "Migration guides" => [ + "v5 to v6" => "migration-guides/v5-to-v6.md", ], "Extra" => [ "Contributing" => "extra/contributing.md", diff --git a/docs/src/assets/logo.svg b/docs/src/assets/logo.svg new file mode 100644 index 000000000..9ee24f914 --- /dev/null +++ b/docs/src/assets/logo.svg @@ -0,0 +1 @@ +ReactiveMP \ No newline at end of file diff --git a/docs/src/concepts/factor-graphs.md b/docs/src/concepts/factor-graphs.md new file mode 100644 index 000000000..71501dbf4 --- /dev/null +++ b/docs/src/concepts/factor-graphs.md @@ -0,0 +1,60 @@ +# [Factor graphs](@id concepts-factor-graphs) + +A **factor graph** is a graphical representation of how a joint probability distribution factorizes into a product of local functions. ReactiveMP.jl uses factor graphs as the underlying structure for all inference computations. + +## [Variables and factors](@id concepts-factor-graphs-variables-and-factors) + +A factor graph has two kinds of nodes: + +- **Variable nodes** — represent the random quantities in your model (latent variables, observed data, or constants). +- **Factor nodes** — represent the local functions (conditional distributions, likelihoods, deterministic transforms) that connect variables together. + +An edge between a factor node and a variable node means that the factor involves that variable. + +Consider a simple model with three variables `x`, `y`, and `z` and two factors `f` and `g`: + +``` + (x) ── [f] ── (y) ── [g] ── (z) +``` + +This graph represents the factorization: + +```math +p(x, y, z) = f(x, y) \cdot g(y, z) +``` + +Each factor is a *local* function: `f` only involves `x` and `y`, and `g` only involves `y` and `z`. Messages can therefore be computed locally at each factor, using only the information from neighboring nodes. + +## [Stochastic and deterministic factors](@id concepts-factor-graphs-node-types) + +ReactiveMP.jl distinguishes two kinds of factor nodes: + +- [`Stochastic`](@ref) nodes represent probability distributions, e.g. `p(x | μ, σ)`. They are used for likelihood terms, priors, and latent variable relationships. +- [`Deterministic`](@ref) nodes represent hard functional constraints, e.g. `z = x + y`. They do not add probability mass — they enforce an exact relationship. + +This distinction matters for how messages are computed and how the variational free energy objective is structured. See [`isdeterministic`](@ref) and [`isstochastic`](@ref). + +## [How ReactiveMP.jl represents factor nodes](@id concepts-factor-graphs-node-registration) + +Every factor in ReactiveMP.jl is a Julia type registered with the [`@node`](@ref) macro. The macro declares the node's name, its type (`Stochastic` or `Deterministic`), and the fixed set of edges (interfaces) it connects to: + +```julia +struct MyFactor end + +@node MyFactor Stochastic [ out, x, y ] +# ^^^^^^^^ ^^^^^^^^^^ ^^^^^^^^^^ +# tag type edges (first = output by convention) +``` + +After registration, `MyFactor` can be used as a factor node in a model. The inference engine then dispatches message update rules defined with [`@rule`](@ref) for that node type. + +!!! note + The `@node` macro only registers the factor's structure. Message update rules must be added separately using [`@rule`](@ref) and [`@marginalrule`](@ref). See [Message update rules](@ref lib-rules) for details. + +ReactiveMP.jl ships with [many predefined nodes](@ref lib-predefined-nodes) for common distributions and operations — Gaussian, Gamma, Beta, Bernoulli, arithmetic operations, and more. Custom nodes can be registered using the same `@node` macro. + +## [Next steps](@id concepts-factor-graphs-next) + +- [Variables](@ref lib-variables) — the three kinds of variable nodes and how they work. +- [Message passing](@ref concepts-message-passing) — how information flows through the graph. +- [Inference lifecycle](@ref concepts-inference-lifecycle) — the three phases of building and running inference. diff --git a/docs/src/concepts/inference-lifecycle.md b/docs/src/concepts/inference-lifecycle.md new file mode 100644 index 000000000..c55ef2b58 --- /dev/null +++ b/docs/src/concepts/inference-lifecycle.md @@ -0,0 +1,100 @@ +# [Inference lifecycle](@id concepts-inference-lifecycle) + +Every inference computation in ReactiveMP.jl goes through three phases: **construction**, **activation**, and **observation**. Understanding these phases is essential when working directly with the engine. + +!!! note + If you are using ReactiveMP.jl through [RxInfer.jl](https://github.com/reactivebayes/RxInfer.jl), these phases are managed for you automatically by the `infer` function. This page is aimed at users working with the low-level API directly. + +## [Phase 1: Construction](@id concepts-inference-lifecycle-construction) + +In the construction phase, you create the variables and factor nodes of your model and connect them together. + +**Variables** are created with one of three constructors depending on their role: + +```julia +x = randomvar() # latent variable — will be inferred +y = datavar() # observed quantity — will receive data +c = constvar(2.0) # fixed constant — never changes +``` + +See [Variables](@ref lib-variables) for a full description of each type. + +**Factor nodes** are connected to variables using the `make_node` machinery (typically called by a model specification layer). Each connection registers the variable with the node and allocates a [`ReactiveMP.MessageObservable`](@ref) stream for that edge. At this point, all streams are **lazy** — they exist as placeholders but are not yet computing anything. + +After construction, the graph looks like this conceptually: + +``` + [datavar: y] ──── [factor: f] ──── (randomvar: x) + unconnected unconnected + streams streams +``` + +!!! note + The degree of a variable (number of connected factors) is determined during construction. Adding connections after activation is not supported. + +## [Phase 2: Activation](@id concepts-inference-lifecycle-activation) + +Activation wires the lazy observable streams into a live reactive network. This is done by calling [`ReactiveMP.activate!`](@ref) on each variable and factor node, passing an options object that bundles inference-time configuration. + +For factor nodes, activation is driven by [`ReactiveMP.FactorNodeActivationOptions`](@ref), which carries: +- The factorization assumption (mean-field, structured, or full BP). +- An optional [stream postprocessor](@ref lib-stream-postprocessors) applied to outbound message, marginal, and score streams (e.g. for scheduling). +- Metadata and approximation method settings. + +For variables, activation is driven by [`ReactiveMP.RandomVariableActivationOptions`](@ref) or [`ReactiveMP.DataVariableActivationOptions`](@ref), which wire up the marginal stream and prediction stream. + +After activation, the graph is live: + +``` + [datavar: y] ──── [factor: f] ──── (randomvar: x) ──► marginal q(x) + ▲ rules streams + (waiting for connected connected + observations) +``` + +Every edge now carries a [`ReactiveMP.MessageObservable`](@ref) that is subscribed to its upstream sources. The marginal at `x` is connected to a [`ReactiveMP.MarginalObservable`](@ref) that will emit updated beliefs every time a message changes. + +## [Phase 3: Observation](@id concepts-inference-lifecycle-observation) + +Once the graph is activated, inference is driven by feeding data into the data variables using [`new_observation!`](@ref): + +```julia +new_observation!(y, 3.14) +``` + +This call pushes a new [`Message`](@ref) wrapping a `PointMass(3.14)` into the data variable's outbound stream. The change propagates reactively through all connected factor nodes, triggering rule computations, which in turn push updated messages to downstream variables, which update their marginals. + +The result is that subscribing to the marginal stream of `x` yields updated posterior beliefs automatically: + +``` + new_observation!(y, 3.14) + │ + ▼ + [datavar: y] ──► message ──► [factor: f] ──► message ──► (randomvar: x) + │ + ▼ + marginal q(x) emits +``` + +You can subscribe to the marginal stream of any [`RandomVariable`](@ref) to receive updated beliefs: + +```julia +subscribe!(get_stream_of_marginals(x), (marginal) -> println("Updated: ", mean(marginal))) +``` + +Multiple calls to [`new_observation!`](@ref) are possible after activation — each one triggers another round of reactive propagation. This makes the engine suitable for streaming/online inference scenarios. + +## [Summary](@id concepts-inference-lifecycle-summary) + +| Phase | What happens | Key functions | +|-------|-------------|---------------| +| **Construction** | Variables and nodes created, edges connected, streams allocated (lazy) | [`randomvar`](@ref), [`datavar`](@ref), [`constvar`](@ref), [`@node`](@ref) | +| **Activation** | Lazy streams wired into a live reactive network | [`ReactiveMP.activate!`](@ref), [`ReactiveMP.FactorNodeActivationOptions`](@ref) | +| **Observation** | Data fed in, messages propagate, marginals update | [`new_observation!`](@ref), [`ReactiveMP.get_stream_of_marginals`](@ref) | + +## [Next steps](@id concepts-inference-lifecycle-next) + +- [Factor nodes](@ref lib-node) — how nodes are implemented and activated. +- [Variables](@ref lib-variables) — stream creation and activation details for each variable type. +- [Callbacks](@ref lib-callbacks) — how to hook into message and marginal computation events. +- [Custom functional form](@ref custom-functional-form) — constraining the functional form of marginals during inference. diff --git a/docs/src/concepts/message-passing.md b/docs/src/concepts/message-passing.md new file mode 100644 index 000000000..aee20795c --- /dev/null +++ b/docs/src/concepts/message-passing.md @@ -0,0 +1,77 @@ +# [Message passing](@id concepts-message-passing) + +Message passing is the algorithm that ReactiveMP.jl uses to perform inference on a [factor graph](@ref concepts-factor-graphs). Instead of computing the full joint distribution, each factor node and each variable node exchange small, local summaries — called **messages** — with their neighbors. The posterior beliefs emerge from combining these messages. + +## [Belief propagation](@id concepts-message-passing-bp) + +**Belief propagation** (also known as the sum-product algorithm) computes *exact* marginal posteriors on tree-shaped graphs. The key idea is that a message from a factor node `f` toward a variable `x` summarizes everything `f` knows about `x` from the rest of the graph: + +```math +\mu_{f \to x}(x) = \int f(x, y, z) \; \mu_{y \to f}(y) \; \mu_{z \to f}(z) \; \mathrm{d}y \; \mathrm{d}z +``` + +The message from `x` back toward `f` collects the beliefs arriving at `x` from all *other* connected factors. The marginal `q(x)` is then the product of all incoming messages at `x`. + +On graphs with cycles, this same procedure is run iteratively (loopy belief propagation) and typically converges to a good approximation. + +## [Variational message passing](@id concepts-message-passing-vmp) + +**Variational message passing** (VMP) is a generalization that performs approximate inference by minimizing the Bethe free energy — a variational objective — rather than computing exact integrals. ReactiveMP.jl implements VMP as the primary inference algorithm because: + +1. It includes exact belief propagation as a special case (no factorization constraints = exact BP). +2. It handles non-conjugate and complex models via the **mean-field** or **structured factorization** assumptions. +3. It admits a local, message-level implementation that fits the reactive computation model naturally. + +Under a mean-field factorization assumption `q(x, y) = q(x) q(y)`, the VMP message from factor `f` toward variable `x` becomes: + +```math +\mu_{f \to x}(x) = \exp \int q(y) \, q(z) \log f(x, y, z) \; \mathrm{d}y \; \mathrm{d}z +``` + +Notice that this uses *marginals* `q(y)` and `q(z)` rather than messages `μ(y)` and `μ(z)`. ReactiveMP.jl tracks this distinction through its [functional dependencies](@ref lib-node-functional-dependencies) policy. + +For a deeper treatment of the theory, see the [PhD dissertation](https://pure.tue.nl/ws/portalfiles/portal/313860204/20231219_Bagaev_hf.pdf) that ReactiveMP.jl is based on. + +## [How ReactiveMP.jl chooses the algorithm](@id concepts-message-passing-dispatch) + +ReactiveMP.jl does not ask you to pick an algorithm up front. Instead, the correct message update rule is selected automatically based on: + +1. **The node type** ([`Stochastic`](@ref) or [`Deterministic`](@ref)) — deterministic nodes always use BP-style messages. +2. **The factorization assumption** attached to the model — mean-field or structured factorization triggers the appropriate VMP rule. +3. **Julia's multiple dispatch** — `@rule` definitions are dispatched on the node type, the outgoing edge, and the types of incoming messages/marginals. + +This means adding a new factorization assumption automatically routes computation to the right rules without changing any node code. + +## [The reactive computation model](@id concepts-message-passing-reactive) + +The word *reactive* in the package name refers to how messages are scheduled. Many message passing libraries build an explicit computation schedule (e.g., forward-backward passes) before inference starts. ReactiveMP.jl takes a different approach: **there is no pre-built schedule**. Instead: + +- Each variable and factor node holds a *reactive stream* (a [`ReactiveMP.MessageObservable`](@ref) or [`ReactiveMP.MarginalObservable`](@ref)) that emits updated values whenever its inputs change. +- When new data arrives via [`new_observation!`](@ref), the change propagates automatically through the graph, triggering only the rules that depend on the updated value. +- The propagation order is determined by the graph structure at runtime, not a static plan. + +## [The reactive computation model](@id concepts-message-passing-reactive) + +The word *reactive* in the package name refers to how messages are scheduled. Many message passing libraries build an explicit computation schedule (e.g., forward-backward passes) before inference starts. ReactiveMP.jl takes a different approach: **there is no pre-built schedule**. Instead: + +- Each variable and factor node holds a *reactive stream* (a [`ReactiveMP.MessageObservable`](@ref) or [`ReactiveMP.MarginalObservable`](@ref)) that emits updated values whenever its inputs change. +- When new data arrives via [`new_observation!`](@ref), the change propagates automatically through the graph, triggering only the rules that depend on the updated value. +- The propagation order is determined by the graph structure at runtime, not a static plan. + +This reactive design is built on top of [Rocket.jl](https://github.com/ReactiveBayes/Rocket.jl), a Julia library for reactive programming with observables. For a higher-level explanation of this paradigm and how to conceptualize messages as streams, see the [Reactive Programming Model](@ref concepts-reactive-programming). Understanding that messages are *streams* rather than *values* helps explain the activation step described in [Inference lifecycle](@ref concepts-inference-lifecycle). + +## [Messages and marginals](@id concepts-message-passing-types) + +ReactiveMP.jl uses two distinct wrapper types: + +- [`Message`](@ref) — a message flowing along a single edge, from a factor toward a variable (or vice versa). +- [`Marginal`](@ref) — the posterior belief at a variable, computed as the normalized product of all incoming messages. + +Both are thin wrappers around a probability distribution object. The separation allows the engine to track metadata such as whether a value is clamped (fixed) or initial (a prior seed), and to carry optional annotations for model evidence computation (see [Annotations](@ref lib-annotations)). + +## [Next steps](@id concepts-message-passing-next) + +- [Messages](@ref lib-message) — detailed description of the `Message` type and message observables. +- [Marginals](@ref lib-marginal) — the `Marginal` type and marginal observables. +- [Message update rules](@ref lib-rules) — how to define and query rules with `@rule` and `@marginalrule`. +- [Inference lifecycle](@ref concepts-inference-lifecycle) — the three phases of construction, activation, and observation. diff --git a/docs/src/concepts/reactive-programming.md b/docs/src/concepts/reactive-programming.md new file mode 100644 index 000000000..b2cd4e5b9 --- /dev/null +++ b/docs/src/concepts/reactive-programming.md @@ -0,0 +1,25 @@ +# [Reactive Programming Model](@id concepts-reactive-programming) + +ReactiveMP.jl is built on a **reactive programming** paradigm. Unlike traditional inference engines that follow a pre-defined, static computation schedule (e.g., performing a forward and backward pass), ReactiveMP.jl operates by reacting to changes in the underlying data. + +## The Mental Model + +To use this package effectively, it helps to shift your thinking from "static values" to "dynamic streams." + +### 1. Observables as Streams +In many algorithms, a message is just a piece of data at a specific point in time. In ReactiveMP.jl, messages and marginals are treated as **Observables**. + +Think of an Observable not as a single number, but as a **stream** of values. Whenever a node performs a computation and produces a new result, it "emits" this value into the stream. Any downstream node listening to this stream will automatically receive the update. + +### 2. Automatic Propagation +Because nodes are connected via these streams, the graph handles its own execution. You do not need to manually trigger a "message passing step." Instead: +- An external event occurs (e.g., a new observation is added via [`new_observation!`](@ref)). +- This change triggers an update in a specific node. +- That node's output changes, which automatically notifies all connected neighbors. +- The change propagates through the graph structure, only visiting nodes that are actually affected by the update. + +This "dependency-driven" execution ensures that we only perform the minimum amount of computation necessary to keep the beliefs up to date. + +## For Deep Dives + +The underlying machinery for this reactive behavior is provided by [Rocket.jl](https://github.com/ReactiveBayes/Rocket.jl). If you want to understand the low-level mechanics of how Observables, triggers, and reactive streams are implemented in Julia, we highly recommend exploring the Rocket.jl documentation. diff --git a/docs/src/custom/custom-addons.md b/docs/src/custom/custom-addons.md deleted file mode 100644 index 910e96f01..000000000 --- a/docs/src/custom/custom-addons.md +++ /dev/null @@ -1,70 +0,0 @@ -# [Custom Addons](@id custom-addons) - -Standard message passing schemes only pass along distributions to other nodes. However, for more advanced usage, there might be a need for passing along additional information in messages and/or marginals. One can for example think of passing along the scaling of the distribution or some information that specifies how the message or marginal was computed, i.e. which messages were used for its computation and which node was preceding it. Another use cases is saving extra debugging information inside messages themselves, e.g. what arguments have been used to compute a message. - -Addons provide a solution here. Basically, addons are structures that contain extra information that are passed along the graph with messages and marginals in a tuple. These addons can be extracted using the `getaddons(message/marginal)` function. Its usage and operations can differ significantly for each application, yet below gives a concise overview on how to implement them on your own. - -## Example - -Suppose that we wish to create an addon that counts the number of computations that preceded some message or marginal. This addon can be created by adding the file `src/addons/count.jl` and by including it in the `ReactiveMP.jl` file. - -### Step 1: Creating the addon structure - -Let's start by defining our new addon structure. This might seem daunting, but basically only requires us to specify the information that we would like to collect. Just make sure that it is specified as a subtype of `AbstractAddon`. In our example this becomes: - -```julia -struct AddonCount{T} <: AbstractAddon - count :: T -end -``` - -You can add additional fields or functions for improved handling, such as `get_count()` or `show()` functions. - -### Step 2: Compute addon value after computing a message - -As a second step we need to specify how the addon behaves when a new message is computed in a factor node. -For this purpose we need to implement a specialized version of the `message_mapping_addon()` function. This function accepts the mapping variables of the factor node and updates the addons by extending the tuple. - -In our example we could write -```julia -# This specification assumes that the default value for addon is `AddonCount(nothing)` -function message_mapping_addon(::AddonCount{Nothing}, mapping, messages, marginals, result, addons) - - # get number of operations of messages - message_count = 0 - for message in messages - message_count += getcount(message) - end - - # get number of operations of marginals - marginal_count = 0 - for marginal in marginals - marginal_count += getcount(marginal) - end - - # extend addons with AddonCount() structure - return AddonCount(message_count + marginal_count + 1) -end - -``` -### Step 3: Computing products - -The goal is to update the `AddonCount` structure when we multiply 2 messages. As a result, we need to write a function that allows us to define this behaviour. This function is called `multiply_addons` and accepts 5 arguments. In our example this becomes - -```julia -function multiply_addons(left_addon::AddonCount, right_addon::AddonCount, new_dist, left_dist, right_dist) - return AddonCount(left_addon.count + right_addon.count + 1) -end -``` - -here we add the number of operations from the addons that are being multiplied and we add one (for the current operation). we are aware that this is likely not valid for iterative message passing schemes, but it still serves as a nice example. the `left_addon` and `right_addon` argument specify the `addoncount` objects that are being multiplied. corresponding to these addons, there are the distributions `left_dist` and `right_dist`, which might contain information for computing the product. the new distribution `new_dist ∝ left_dist * right_dist` is also passed along for potentially reusing the result of earlier computations. - -### More information - -For more advanced information check the implementation of the log-scale or memory addons. - -### Built-in addons - -```@docs -ReactiveMP.AddonDebug -``` diff --git a/docs/src/extra/extensions.md b/docs/src/extra/extensions.md index 575326c91..8f9bbfb6b 100644 --- a/docs/src/extra/extensions.md +++ b/docs/src/extra/extensions.md @@ -1,14 +1,54 @@ -# Extensions and interaction with the Julia ecosystem +# [Extensions and ecosystem integration](@id extra-extensions) -`ReactiveMP.jl` exports extra functionality if other Julia packages are loaded in the same environment. +`ReactiveMP.jl` activates extra functionality when other Julia packages are loaded alongside it. These are implemented as Julia [package extensions](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-dependencies) (weak dependencies) and require no additional configuration — simply `using` the relevant package is enough. -## Optimisers.jl +## [Optimisers.jl](@id extra-extensions-optimisers) -The [`Optimizers.jl`](https://github.com/FluxML/Optimisers.jl) package defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models. -The optimizers defined in the `Optimziers.jl` are compatible with the CVI approximation method. +**Package:** [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl) -## DiffResults.jl (loaded automatically with the `ForwardDiff.jl`) +**What it provides:** Gradient-based optimizers (Adam, ADAM, NADAM, RMSProp, etc.) compatible with the [`CVI`](@ref) approximation method. -The [`DiffResults.jl`](https://github.com/JuliaDiff/DiffResults.jl) provides the `DiffResult` type, which can be passed to in-place differentiation methods instead of an output buffer. -If loaded in the current Julia session enables faster derivatives with the `ForwardDiffGrad` option in the CVI approximation method (in the `Gaussian` case). +```julia +using ReactiveMP, Optimisers +meta = DeltaMeta(method = CVI( + rng, + n_samples = 100, + n_iterations = 50, + opt = Optimisers.Adam(0.01), # ← any Optimisers.jl rule +)) + +y ~ f(x) where { meta = meta } +``` + +**How it works internally:** The extension implements `ReactiveMP.cvi_setup` and `ReactiveMP.cvi_update!` for `Optimisers.AbstractRule`, delegating to `Optimisers.init` and `Optimisers.apply!`. This maps the Optimisers.jl stateful optimizer API onto the CVI update loop. + +## [DiffResults.jl](@id extra-extensions-diffresults) + +**Package:** [`DiffResults.jl`](https://github.com/JuliaDiff/DiffResults.jl) — loaded automatically when `ForwardDiff.jl` is present. + +**What it provides:** Faster derivative computation for the [`ForwardDiffGrad`](@ref) gradient estimator inside CVI, in the special case where all inputs are Gaussian distributions. + +When `DiffResults` is available, `ForwardDiffGrad` uses `DiffResults.DiffResult` as an output buffer for in-place differentiation, avoiding redundant forward passes. This can meaningfully reduce the per-iteration cost of CVI in purely Gaussian models. + +No explicit configuration is needed — the extension activates automatically whenever `ForwardDiff` (and transitively `DiffResults`) is loaded into the session. + +## [ExponentialFamilyProjection.jl](@id extra-extensions-projection) + +**Package:** [`ExponentialFamilyProjection.jl`](https://github.com/ReactiveBayes/ExponentialFamilyProjection.jl) + +**What it provides:** Enables [`CVIProjection`](@ref) for use inside [Delta nodes](@ref lib-nodes-delta). + +[`CVIProjection`](@ref) extends CVI by projecting the resulting approximate message onto the nearest member of a target exponential family. This projection step requires `ExponentialFamilyProjection.jl`. Without it, placing `CVIProjection` in a delta node raises an informative error. + +```julia +using ReactiveMP, ExponentialFamilyProjection + +meta = DeltaMeta(method = CVIProjection( + sampling_strategy = FullSampling(100), +)) + +y ~ f(x) where { meta = meta } +``` + +The extension also defines `prod` rules on `DivisionOf` objects needed for the backward message computation in the delta node, and registers `CVIProjection` as compatible with the delta node approximation framework. diff --git a/docs/src/index.md b/docs/src/index.md index b84ac0227..6848b8484 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,27 +3,42 @@ ReactiveMP.jl *Julia package for reactive message passing Bayesian inference engine on a factor graph.* +`ReactiveMP.jl` is a low-level inference engine that implements variational message passing on factor graphs. It is designed for advanced users who need fine-grained control over message passing, custom factor nodes, and custom update rules. For most use cases, the [RxInfer.jl](https://github.com/reactivebayes/RxInfer.jl) package provides a convenient model specification layer on top of ReactiveMP.jl. + !!! note - This package exports only an inference engine, for the full ecosystem with convenient model and constraints specification we refer user to the [`RxInfer.jl`](https://github.com/reactivebayes/RxInfer.jl) package and its [documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). + This package exports only an inference engine. For the full ecosystem with convenient model and constraints specification, see [`RxInfer.jl`](https://github.com/reactivebayes/RxInfer.jl) and its [documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). + +## [Start here](@id index-start-here) + +If you are new to ReactiveMP.jl, read the Concepts section first. It explains the key ideas without assuming prior familiarity with the codebase: -## Ideas and principles behind `ReactiveMP.jl` +1. **[Factor graphs](@ref concepts-factor-graphs)** — what factor graphs are and how ReactiveMP.jl represents them. +2. **[Message passing](@ref concepts-message-passing)** — how belief propagation and variational message passing work, and the reactive computation model. +3. **[Inference lifecycle](@ref concepts-inference-lifecycle)** — the three phases every inference run goes through: construction, activation, and observation. -`ReactiveMP.jl` is a particular implementation of message passing on factor graphs, which does not create any specific message passing schedule in advance, but rather _reacts_ on changes in the data source (hence _reactive_ in the name of the package). The detailed explanation of the ideas and principles behind the _Reactive Message Passing_ can be found in PhD disseration of _Dmitry Bagaev_ titled [__Reactive Probabilistic Programming for Scalable Bayesian Inference__](https://pure.tue.nl/ws/portalfiles/portal/313860204/20231219_Bagaev_hf.pdf) ([link2](https://research.tue.nl/nl/publications/reactive-probabilistic-programming-for-scalable-bayesian-inferenc), [link3](https://github.com/bvdmitri/phdthesis)). +After reading the Concepts section, the Library section provides the full API reference for each component. -## Examples and tutorials +## [Ideas and principles behind `ReactiveMP.jl`](@id index-ideas) + +`ReactiveMP.jl` is a particular implementation of message passing on factor graphs, which does not create any specific message passing schedule in advance, but rather _reacts_ on changes in the data source (hence _reactive_ in the name of the package). The detailed explanation of the ideas and principles behind the _Reactive Message Passing_ can be found in PhD dissertation of _Dmitry Bagaev_ titled [__Reactive Probabilistic Programming for Scalable Bayesian Inference__](https://pure.tue.nl/ws/portalfiles/portal/313860204/20231219_Bagaev_hf.pdf) ([link2](https://research.tue.nl/nl/publications/reactive-probabilistic-programming-for-scalable-bayesian-inferenc), [link3](https://github.com/bvdmitri/phdthesis)). + +## [Examples and tutorials](@id index-examples) The `ReactiveMP.jl` package is intended for advanced users with a deep understanding of message passing principles. -Accesible tutorials and examples are available in the [RxInfer documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). +Accessible tutorials and examples are available in the [RxInfer documentation](https://reactivebayes.github.io/RxInfer.jl/stable/). ## Table of Contents ```@contents Pages = [ + "concepts/factor-graphs.md", + "concepts/message-passing.md", + "concepts/inference-lifecycle.md", "lib/nodes.md", + "lib/variables.md", "lib/message.md", "lib/marginal.md", "lib/rules.md", - "lib/nodes.md", "lib/helpers.md", "lib/algebra.md", "extra/contributing.md", diff --git a/docs/src/lib/algebra.md b/docs/src/lib/algebra.md index fd22625d8..45ec3cc23 100644 --- a/docs/src/lib/algebra.md +++ b/docs/src/lib/algebra.md @@ -1,20 +1,87 @@ -# [Algebra common utilities](@id lib-helpers-algebra-common) +# [Algebra utilities](@id lib-helpers-algebra-common) -## [diageye](@id lib-helpers-algebra-diageye) +This page documents linear-algebra building blocks used internally by ReactiveMP.jl's built-in message update rules. They are exposed publicly so that custom rules can reuse them without reimplementing common operations. + +## [Matrix constructors](@id lib-helpers-algebra-matrices) + +### `diageye` — identity matrix + +[`diageye`](@ref) is a convenience alias for constructing a dense identity matrix of a given size. It is equivalent to `Matrix{Float64}(I, n, n)` but reads more clearly in rule code: + +```jldoctest +julia> using ReactiveMP; diageye(3) +3×3 Matrix{Float64}: + 1.0 0.0 0.0 + 0.0 1.0 0.0 + 0.0 0.0 1.0 +``` + +### `CompanionMatrix` — AR coefficient matrix + +[`CompanionMatrix`](@ref) represents the companion form of an AR(p) coefficient vector `θ`: + +```math +C(\theta) = \begin{bmatrix} \theta_1 & \theta_2 & \cdots & \theta_p \\ 1 & 0 & \cdots & 0 \\ 0 & 1 & \cdots & 0 \\ \vdots & & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{bmatrix} +``` + +It is a lazy `AbstractMatrix` — no allocation is made until an element is accessed or a product is computed. Specialized `*` methods exploit the sparse structure for efficient matrix-vector and matrix-matrix products. + +This matrix appears in the [Autoregressive node](@ref lib-nodes-ar) and [Continuous transition node](@ref lib-nodes-ctransition) when expressing a higher-order AR process as a first-order state-space model. + +### `PermutationMatrix` — structured permutation + +[`PermutationMatrix`](@ref) represents an `n×n` permutation matrix as a length-`n` index vector rather than a dense matrix. It supports efficient multiplication with vectors and matrices via specialized `mul!` dispatch, and its inverse is simply its adjoint. + +Permutation matrices appear in normalizing flow layers ([`PermutationLayer`](@ref)) to shuffle input dimensions between coupling layers. + +### `StandardBasisVector` — sparse one-hot vector + +[`StandardBasisVector`](@ref) represents a standard Cartesian basis vector — all zeros except one element — without allocating a dense array. It supports dot products, outer products, and matrix multiplication, all using the sparse structure. ```@docs diageye ReactiveMP.CompanionMatrix ReactiveMP.PermutationMatrix ReactiveMP.StandardBasisVector -ReactiveMP.GammaShapeLikelihood -ReactiveMP.ImportanceSamplingApproximation +``` + +## [In-place scalar and array operations](@id lib-helpers-algebra-inplace) + +These functions return `alpha * A` or `-A`, **reusing the storage of `A` when the type allows it** (i.e. when `A` is a mutable `Array`). For immutable arrays or scalars they fall back to a regular allocation. This makes them safe to use in rule code regardless of whether the input is mutable. + +```@docs ReactiveMP.mul_inplace! ReactiveMP.negate_inplace! +``` + +## [Trace and rank-1 update utilities](@id lib-helpers-algebra-trace) + +These functions implement common linear-algebra patterns that appear repeatedly in Gaussian message rules. + +### `mul_trace` — allocation-free `tr(A·B)` + +[`ReactiveMP.mul_trace`](@ref) computes `tr(A * B)` directly without forming the full product matrix, saving an `O(n²)` allocation for square matrices. + +### `rank1update` — `A + x·yᵀ` via BLAS + +[`ReactiveMP.rank1update`](@ref) computes `A + x * y'`. For `BlasFloat` element types it dispatches to the BLAS `ger!` routine, which is highly optimized for this pattern. + +### `v_a_vT` — `v·a·vᵀ` + +[`ReactiveMP.v_a_vT`](@ref) computes `v * a * v'` or `v₁ * a * v₂'`. When `a` is a scalar it avoids forming a temporary matrix. Specialized methods exist for [`StandardBasisVector`](@ref) inputs that exploit the one-hot structure. + +```@docs ReactiveMP.mul_trace ReactiveMP.rank1update ReactiveMP.v_a_vT +``` + +## [Other utilities](@id lib-helpers-algebra-other) + +```@docs +ReactiveMP.GammaShapeLikelihood +ReactiveMP.ImportanceSamplingApproximation ReactiveMP.powerset ReactiveMP.besselmod ReactiveMP.isonehot -``` \ No newline at end of file +``` diff --git a/docs/src/lib/annotations.md b/docs/src/lib/annotations.md new file mode 100644 index 000000000..3691549d2 --- /dev/null +++ b/docs/src/lib/annotations.md @@ -0,0 +1,73 @@ +# [Annotations](@id lib-annotations) + +Messages and marginals in ReactiveMP carry a probability distribution as their primary content. Annotations are an optional side-channel that can travel alongside a message, holding arbitrary extra information keyed by `Symbol`. Typical uses include tracking log-scale factors (see [`LogScaleAnnotations`](@ref lib-annotations-logscale)), recording which messages were used to compute a result, or attaching debugging information. + +Annotations are designed to be zero-cost when unused: the underlying dictionary is only allocated on the first write. + +## AnnotationDict + +Every message and marginal holds an [`ReactiveMP.AnnotationDict`](@ref). The basic operations are: + +```@docs +ReactiveMP.AnnotationDict +ReactiveMP.annotate! +ReactiveMP.get_annotation +ReactiveMP.has_annotation +``` + +## Annotation processors + +Annotation processors are subtypes of [`ReactiveMP.AbstractAnnotations`](@ref) that define *how* annotations are written and merged. There are three integration points: + +- **Before a rule executes** — [`ReactiveMP.pre_rule_annotations!`](@ref) is called with the processor, the rule's `AnnotationDict`, the `MessageMapping`, the incoming messages and marginals. Use this to write annotations that does not depend on what the rule computed. +- **After a rule executes** — [`ReactiveMP.post_rule_annotations!`](@ref) is called with the processor, the rule's `AnnotationDict`, the `MessageMapping`, the incoming messages and marginals, and the result distribution. Use this to write annotations that depend on what the rule computed. +- **During a message product** — [`ReactiveMP.post_product_annotations!`](@ref) is called with the processor, a fresh merged `AnnotationDict`, and the left and right annotation dicts together with the distributions involved. Use this to merge annotations from the two incoming messages into the product message. + +```@docs +ReactiveMP.AbstractAnnotations +ReactiveMP.pre_rule_annotations! +ReactiveMP.post_rule_annotations! +ReactiveMP.post_product_annotations! +``` + +## Implementing a custom annotation processor + +To add a new kind of annotation, subtype `AbstractAnnotations` and implement the two callbacks: + +```julia +using ReactiveMP + +# not exported by default +import ReactiveMP: AbstractAnnotations, AnnotationDict, has_annotation, get_annotation, annotate! + +struct CountAnnotations <: AbstractAnnotations end + +# Called before each rule execution +function ReactiveMP.pre_rule_annotations!(::CountAnnotations, ann::AnnotationDict, mapping, messages, marginals) + return nothing +end + +# Called after each rule execution +function ReactiveMP.post_rule_annotations!(::CountAnnotations, ann::AnnotationDict, mapping, messages, marginals, result) + prev = has_annotation(ann, :count) ? get_annotation(ann, Int, :count) : 0 + annotate!(ann, :count, prev + 1) + return nothing +end + +# Called when two messages are multiplied +function ReactiveMP.post_product_annotations!(::CountAnnotations, merged::AnnotationDict, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, left_dist, right_dist) + left_count = has_annotation(left_ann, :count) ? get_annotation(left_ann, Int, :count) : 0 + right_count = has_annotation(right_ann, :count) ? get_annotation(right_ann, Int, :count) : 0 + annotate!(merged, :count, left_count + right_count) + return nothing +end +``` + +Processors are passed to `FactorNodeActivationOptions` (for rule-time annotation) and [`ReactiveMP.MessageProductContext`](@ref) (for product-time merging) when building a model. Both sites must be configured — see the RxInfer documentation for how to set this up at the model level. + +## Built-in annotation processors + +```@contents +Pages = ["annotations/logscale.md", "annotations/input_arguments.md"] +Depth = 1 +``` diff --git a/docs/src/lib/annotations/input_arguments.md b/docs/src/lib/annotations/input_arguments.md new file mode 100644 index 000000000..7b13ce9d3 --- /dev/null +++ b/docs/src/lib/annotations/input_arguments.md @@ -0,0 +1,51 @@ +# [Input arguments annotations](@id lib-annotations-input-arguments) + +## Background: tracing rule inputs + +During inference, every message flowing along an edge is computed by a message +update rule. `InputArgumentsAnnotations` records what went into each rule call — +the `MessageMapping` (which node and interface the rule was for), the incoming +messages, the incoming marginals, and the result distribution — and propagates +that record through subsequent message products. + +This is useful for debugging and for implementing callbacks that need to inspect +the full provenance of a message: rather than re-running or re-examining the +model structure, the record travels with the message itself. + +## What gets stored + +After each rule execution a [`RuleInputArgumentsRecord`](@ref) is written into +the message's annotation dict under the `:rule_input_arguments` key. When two +messages are multiplied, their records are merged into a +[`ProductInputArgumentsRecord`](@ref) that contains all contributing records as +a flat list, regardless of how deeply nested the products were. + +## Reading input arguments from a message + +```julia +using ReactiveMP + +# ann is the AnnotationDict of some message +record = get_rule_input_arguments(ann) + +if record isa RuleInputArgumentsRecord + println("single rule: ", record.mapping) + println("messages: ", record.messages) + println("marginals: ", record.marginals) + println("result: ", record.result) +elseif record isa ProductInputArgumentsRecord + for r in record.mappings + println("contributed rule: ", r.mapping) + end +end +``` + +## API + +```@docs +ReactiveMP.InputArgumentsAnnotations +ReactiveMP.RuleInputArgumentsRecord +ReactiveMP.ProductInputArgumentsRecord +ReactiveMP.get_rule_input_arguments +ReactiveMP.AddonMemory +``` diff --git a/docs/src/lib/annotations/logscale.md b/docs/src/lib/annotations/logscale.md new file mode 100644 index 000000000..6a4de74a6 --- /dev/null +++ b/docs/src/lib/annotations/logscale.md @@ -0,0 +1,82 @@ +# [Log-scale annotations](@id lib-annotations-logscale) + +## Background: scale factors in message passing + +In sum-product message passing on a Forney-style factor graph, a message ``\vec{\mu}_{s_j}(s_j)`` flowing along an edge is in general *unnormalised*. It can be decomposed as + +```math +\vec{\mu}_{s_j}(s_j) = \beta_{s_j} \cdot \hat{p}_{s_j}(s_j), +``` + +where ``\hat{p}_{s_j}(s_j)`` is the normalised probability distribution (what ReactiveMP stores as the message's data) and ``\beta_{s_j}`` is the **scale factor** — a positive scalar that carries the accumulated normalisation constant of the message. + +A key result from [van Erp et al. (2023)](https://arxiv.org/abs/2306.05965) is that in an acyclic graph the product of two colliding messages integrates to exactly the model evidence: + +```math +\int \vec{\mu}_{s_j}(s_j)\, \overleftarrow{\mu}_{s_j}(s_j)\, \mathrm{d}s_j = p(y = \hat{y}). +``` + +This means the model evidence can be read off locally at *any* edge by tracking the scale factor. This enables **Bayesian model comparison** — averaging, selection, and combination — to be performed automatically as part of the same message-passing run that computes posteriors, without any separate evidence computation. + +## Log-scale factors + +For numerical stability, ReactiveMP tracks the *logarithm* of the scale factor, ``\log \beta``. When two messages are multiplied to form a product message, the log-scale of the result is: + +```math +\log \beta_\text{new} = \log \beta_\text{left} + \log \beta_\text{right} + \texttt{compute\_logscale}(\hat{p}_\text{new},\, \hat{p}_\text{left},\, \hat{p}_\text{right}), +``` + +where ``\texttt{compute\_logscale}(\hat{p}_\text{new}, \hat{p}_\text{left}, \hat{p}_\text{right})`` is the log of the normalisation constant of the product: + +```math +\texttt{compute\_logscale} = \log Z = \log \int \hat{p}_\text{left}(x)\, \hat{p}_\text{right}(x)\, \mathrm{d}x. +``` + +This function is defined in [`BayesBase.jl`](https://github.com/ReactiveBayes/BayesBase.jl) and extended for specific distribution families in [`ExponentialFamily.jl`](https://github.com/ReactiveBayes/ExponentialFamily.jl). + +## When is the log-scale zero? + +The log-scale ``\log \beta`` is zero exactly when the rule's raw factor product already integrates to 1 — i.e. the rule only reparameterises its inputs without dividing out a normalisation constant. This is the case for most **conjugate continuous rules**, for example a Normal message rule that takes a Normal prior and a PointMass precision and returns a new Normal: the computation is a direct parameter transformation and nothing is normalised away. + +The log-scale is **non-zero** whenever the rule involves a factor that does not integrate to 1 over the variable being messaged to. Two important classes: + +- **Discrete observed nodes** — for example, a Bernoulli node with observed output ``y = 1`` and unknown ``p``. The raw factor is ``f(p) = p``, which integrates to ``\tfrac{1}{2}`` over ``[0,1]``. The normalised message representation is ``\mathrm{Beta}(2, 1)``, but the lost constant ``\tfrac{1}{2}`` must be recorded: ``\log \beta = -\log 2``. + +- **Mixture and categorical nodes** — the backward message toward the model selection variable ``m`` contains the model evidence of each component as its scale factor (see equation 42 in van Erp et al. (2023)). This is the mechanism that makes automated Bayesian model comparison possible. + +Critically, **the normalised output distribution alone does not reveal the log-scale**. Both a zero-logscale rule and a non-zero-logscale rule return a properly normalised distribution object — looking at `Beta(2,1)` does not tell you that `log β = -log 2` was lost. This is why `@logscale` must be set explicitly by the rule author. + +## Inside rule bodies: `@logscale` + +When a message update rule computes a message whose normalisation constant is known analytically, it records the log-scale factor using the `@logscale` macro: + +```julia +@rule NormalMeanVariance(:out, Marginalisation) (m_μ::UnivariateNormalDistributionsFamily, m_σ²::PointMass) = begin + @logscale 0 # conjugate reparameterisation — no normalisation constant is lost + return NormalMeanVariance(mean(m_μ), var(m_μ) + mean(m_σ²)) +end +``` + +For rules where a non-trivial normalisation constant is divided out, the exact value must be provided: + +```julia +@rule Bernoulli(:p, Marginalisation) (m_out::PointMass,) = begin + @logscale log(mean(m_out)) # log-likelihood of the observed value + return Beta(...) +end +``` + +If a rule does not call `@logscale` and `LogScaleAnnotations` is active, ReactiveMP applies a fallback: if all incoming messages and marginals are `PointMass` distributions (i.e. the node is deterministic given its inputs) the log-scale is set to zero. In all other cases an error is raised to prevent silently wrong model evidence computations. + +## API + +```@docs +ReactiveMP.LogScaleAnnotations +ReactiveMP.getlogscale +ReactiveMP.@logscale +ReactiveMP.AddonLogScale +``` + +## References + +- van Erp, B., Nuijten, W. W. L., van de Laar, T., & de Vries, B. (2023). *Automating Model Comparison in Factor Graphs*. Entropy, 25(8), 1138. [https://doi.org/10.3390/e25081138](https://doi.org/10.3390/e25081138) diff --git a/docs/src/lib/approximations.md b/docs/src/lib/approximations.md new file mode 100644 index 000000000..2ea33085b --- /dev/null +++ b/docs/src/lib/approximations.md @@ -0,0 +1,99 @@ +# [Approximation methods](@id lib-approximations) + +Approximation methods are used when exact message computation is intractable — most commonly inside [Delta nodes](@ref lib-nodes-delta), where the factor is a nonlinear or non-conjugate function `y = f(x)`. Each method trades off accuracy, computational cost, and assumptions about `f`. + +All approximation methods are passed through [`DeltaMeta`](@ref) or [`FlowMeta`](@ref). + +## [Choosing a method](@id lib-approximations-choosing) + +| Method | Best for | Dimensionality | Requires | +|--------|----------|----------------|---------| +| `Linearization` | Smooth, nearly linear `f` | Any | ForwardDiff (auto) | +| `Unscented` | Smooth nonlinear `f` | Low–moderate | Nothing | +| `GaussHermiteCubature` | Univariate integrals with Gaussian inputs | Univariate | Point count `p` | +| `GaussLaguerreQuadrature` | Integrals over `[0, ∞)` | Univariate | Point count `n` | +| `srcubature` / `SphericalRadialCubature` | Multivariate Gaussian integrals | Multivariate | Nothing | +| `LaplaceApproximation` | Unimodal posteriors, differentiable `f` | Any | ForwardDiff + Optim | +| [`CVI`](@ref) | Black-box or non-differentiable `f` | Any | Optimizer + gradient | +| [`CVIProjection`](@ref) | CVI + exponential family projection | Any | ExponentialFamilyProjection.jl | +| [`ImportanceSamplingApproximation`](@ref) | General expectations via sampling | Any | Proposal distribution | + +## [Deterministic approximations](@id lib-approximations-deterministic) + +### Linearization + +[`Linearization`](@ref) approximates `f` by its first-order Taylor expansion around the current operating point. Jacobians are computed automatically using [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). This is the default method for [`FlowMeta`](@ref) and a common choice for [`DeltaMeta`](@ref) when `f` is smooth and not highly nonlinear. + +### Unscented transform + +[`Unscented`](@ref) (also [`UT`](@ref) / [`UnscentedTransform`](@ref)) propagates a deterministic set of *sigma points* through `f` and fits a Gaussian to the outputs. It captures mean and covariance through nonlinearities more accurately than linearization, without requiring derivatives. The number of sigma points scales linearly with the input dimension. + +### Gauss-Hermite cubature + +`GaussHermiteCubature` computes expectations of the form `∫ g(x) N(x; μ, σ²) dx` using a fixed set of quadrature points and weights optimized for Gaussian measures. It is exact for polynomials up to a certain degree determined by the number of points `p`: + +```julia +DeltaMeta(method = GaussHermiteCubature(21)) # 21-point rule +``` + +### Gauss-Laguerre quadrature + +`GaussLaguerreQuadrature` computes expectations over the half-line `[0, ∞)` — useful when the input has a Gamma distribution or similar semi-infinite support. + +### Spherical radial cubature + +`srcubature()` constructs a spherical-radial cubature rule for multivariate Gaussian integrals, using `2d + 1` deterministic points (where `d` is the input dimension). It provides a good balance between accuracy and cost for moderate dimensions. + +### Laplace approximation + +`LaplaceApproximation` finds the mode of the log-unnormalized posterior (using [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl)) and fits a Gaussian at that mode using the local curvature (via ForwardDiff). Best for unimodal, differentiable posteriors. + +## [Stochastic approximations](@id lib-approximations-stochastic) + +### CVI — Constrained Variational Inference + +[`CVI`](@ref) and [`ProdCVI`](@ref) approximate messages using stochastic gradient optimization of a variational objective. A gradient estimator (default: [`ForwardDiffGrad`](@ref)) computes the gradient of the log-likelihood with respect to the natural parameters of the approximating distribution. An optimizer (default: `Adam`) applies gradient steps until convergence. + +```julia +DeltaMeta(method = CVI( + rng, # random number generator + n_samples, # samples for gradient estimation + n_iterations, # gradient steps per message update + Adam(params), # optimizer +)) +``` + +!!! note + Loading [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) unlocks additional optimizers for use with CVI. See the [Extensions](@ref extra-extensions) page. + +### CVIProjection + +[`CVIProjection`](@ref) extends CVI by projecting the result onto the nearest member of a chosen exponential family, using [ExponentialFamilyProjection.jl](https://github.com/ReactiveBayes/ExponentialFamilyProjection.jl). This guarantees that the output is a valid member of the target family. + +!!! note + `CVIProjection` requires the `ExponentialFamilyProjection` package to be loaded. Without it, using `CVIProjection` in a delta node will throw an informative error. + +### Importance sampling + +[`ImportanceSamplingApproximation`](@ref) estimates expectations by drawing samples from a proposal distribution and reweighting. It is the most flexible method but converges slowly in high dimensions. + +## [API reference](@id lib-approximations-api) + +```@docs +ReactiveMP.Linearization +ReactiveMP.local_linearization +ReactiveMP.Unscented +ReactiveMP.sigma_points_weights +ReactiveMP.UT +ReactiveMP.UnscentedTransform +ReactiveMP.CVI +ReactiveMP.ProdCVI +ReactiveMP.ForwardDiffGrad +ReactiveMP.CVIProjection +ReactiveMP.CVISamplingStrategy +ReactiveMP.FullSampling +ReactiveMP.MeanBased +ReactiveMP.ProposalDistributionContainer +ReactiveMP.cvi_setup! +ReactiveMP.cvi_update! +``` diff --git a/docs/src/lib/callbacks.md b/docs/src/lib/callbacks.md new file mode 100644 index 000000000..a657d2c57 --- /dev/null +++ b/docs/src/lib/callbacks.md @@ -0,0 +1,71 @@ +# [Callbacks in the Message Passing Procedure](@id lib-callbacks) + +ReactiveMP provides a way to "hook" into the message passing procedure and listen to various events +via "callbacks". This can be useful, for example, to debug messages or monitor the order of computations. + +```@docs +ReactiveMP.Event +ReactiveMP.event_name +ReactiveMP.handle_event +ReactiveMP.invoke_callback +ReactiveMP.merge_callbacks +ReactiveMP.MergedCallbacks +``` + +## Event naming convention + +Every event in ReactiveMP is a concrete subtype of [`ReactiveMP.Event{E}`](@ref) where `E` is a `Symbol` identifying the event. +The naming convention is straightforward: for an event identified by the symbol `:event_name`, the corresponding struct is called `EventNameEvent`. +For example: + +| Symbol | Struct | +|--------|--------| +| `:before_message_rule_call` | [`ReactiveMP.BeforeMessageRuleCallEvent`](@ref) | +| `:after_product_of_two_messages` | [`ReactiveMP.AfterProductOfTwoMessagesEvent`](@ref) | +| `:before_form_constraint_applied` | [`ReactiveMP.BeforeFormConstraintAppliedEvent`](@ref) | + +Each event struct carries the relevant data as fields, so you can inspect what happened during inference. +You can use [`ReactiveMP.event_name`](@ref) to retrieve the symbol from any event type: + +```@example callbacks +using ReactiveMP #hide +ReactiveMP.event_name(ReactiveMP.BeforeProductOfTwoMessagesEvent) +``` + +To see which fields an event carries, use the standard Julia introspection: + +```julia +julia> ?ReactiveMP.BeforeProductOfTwoMessagesEvent +``` + +## Event spans + +Certain events create a "span". For example all "before" and "after" events +can be considered together. To track these relationships ReactiveMP uses the +`span_id` field in such events and uses the [`ReactiveMP.generate_span_id`](@ref) +function to generate shared ids. + +```@docs +ReactiveMP.generate_span_id +``` + +Custom callbacks can overwrite the `ReactiveMP.generate_span_id` to return `nothing` +if necessary. Note, however, that [`ReactiveMP.MergedCallbacks`](@ref) would still +use the default implementation. + +## All defined events + +Here is the list of predefined event types, to which a custom callback handler can react to. + +```@docs +ReactiveMP.BeforeMessageRuleCallEvent +ReactiveMP.AfterMessageRuleCallEvent +ReactiveMP.BeforeProductOfTwoMessagesEvent +ReactiveMP.AfterProductOfTwoMessagesEvent +ReactiveMP.BeforeProductOfMessagesEvent +ReactiveMP.AfterProductOfMessagesEvent +ReactiveMP.BeforeFormConstraintAppliedEvent +ReactiveMP.AfterFormConstraintAppliedEvent +ReactiveMP.BeforeMarginalComputationEvent +ReactiveMP.AfterMarginalComputationEvent +``` diff --git a/docs/src/lib/helpers.md b/docs/src/lib/helpers.md index 0f01fe3bf..8c0c985c6 100644 --- a/docs/src/lib/helpers.md +++ b/docs/src/lib/helpers.md @@ -1,14 +1,46 @@ # [Helper utilities](@id lib-helpers) -`ReactiveMP` implements various structures/functions/methods as "helper" structures that might be useful in various contexts. +This page documents utility types and functions that appear at the boundaries of the inference engine — primarily in custom node and rule implementations. They are not needed for everyday inference with built-in nodes, but become useful when writing [`@rule`](@ref) definitions or building new factor node types. + +## [Iteration helpers](@id lib-helpers-iteration) + +When a message update rule computes the outgoing message on edge `k` of a factor node, it needs the incoming messages from *all other edges* — every edge except `k`. The [`ReactiveMP.SkipIndexIterator`](@ref) provides an allocation-free view of a collection that skips one index. + +The constructor [`skipindex`](@ref) is the standard way to create one: + +```julia +# messages is a length-3 collection; compute outgoing message for edge 2 +# by iterating over edges 1 and 3 only +other = ReactiveMP.skipindex(messages, 2) +collect(other) # [messages[1], messages[3]] +``` + +This is used internally inside `@rule` dispatch to pass only the relevant inbound messages to the rule computation. ```@docs ReactiveMP.SkipIndexIterator ReactiveMP.skipindex +``` + +## [Macro utilities](@id lib-helpers-macro) + +The `ReactiveMP.MacroHelpers` submodule contains building blocks used by the [`@node`](@ref) and [`@rule`](@ref) macros to parse and transform Julia type expressions. These are implementation details of the macro system, but they are documented here for completeness and for users who want to understand or extend the macro infrastructure. + +| Function | Purpose | +|----------|---------| +| `ReactiveMP.MacroHelpers.ensure_symbol` | Assert that an expression is a `Symbol`; error otherwise | +| `ReactiveMP.MacroHelpers.bottom_type` | Extract the base type `T` from expressions like `Type{<:T}`, `typeof(T)`, or `T` | +| `ReactiveMP.MacroHelpers.upper_type` | Wrap a type expression into `Type{<:T}` form for dispatch | +| `ReactiveMP.MacroHelpers.proxy_type` | Wrap a type with a proxy type as `ProxyType{<:T}` | +| `ReactiveMP.MacroHelpers.@proxy_methods` | Generate forwarding method definitions for a proxy wrapper type | + +`@proxy_methods` is the most user-facing of these. It generates a set of method forwarding stubs so that a thin wrapper type transparently delegates calls to its wrapped type, without hand-writing each delegation. + +```@docs ReactiveMP.MacroHelpers.proxy_type ReactiveMP.MacroHelpers.ensure_symbol ReactiveMP.MacroHelpers.@proxy_methods ReactiveMP.MacroHelpers.upper_type ReactiveMP.MacroHelpers.bottom_type -``` \ No newline at end of file +``` diff --git a/docs/src/lib/marginal.md b/docs/src/lib/marginal.md index 08ca2e56c..12dd9ed76 100644 --- a/docs/src/lib/marginal.md +++ b/docs/src/lib/marginal.md @@ -14,7 +14,7 @@ From an implementation point a view the `Marginal` structure does nothing but ho ReactiveMP.getdata(marginal::Marginal) ReactiveMP.is_clamped(marginal::Marginal) ReactiveMP.is_initial(marginal::Marginal) -ReactiveMP.getaddons(marginal::Marginal) +ReactiveMP.getannotations(marginal::Marginal) ReactiveMP.as_marginal ReactiveMP.to_marginal ``` @@ -23,7 +23,7 @@ ReactiveMP.to_marginal using ReactiveMP, BayesBase, ExponentialFamily distribution = ExponentialFamily.NormalMeanPrecision(0.0, 1.0) -marginal = Marginal(distribution, false, true, nothing) +marginal = Marginal(distribution, false, true) ``` ```@example marginal @@ -36,4 +36,20 @@ logpdf(marginal, 1.0) ```@example marginal is_clamped(marginal), is_initial(marginal) -``` \ No newline at end of file +``` + +## Marginal observable + +Within the reactive message passing framework, marginals are not computed once and stored as values — instead they live as *streams* that continuously emit updated beliefs as new messages arrive. `MarginalObservable` is the container for such a stream. + +```@docs +ReactiveMP.MarginalObservable +``` + +Every [`ReactiveMP.AbstractVariable`](@ref) holds one `MarginalObservable`, accessed via [`ReactiveMP.get_stream_of_marginals`](@ref). The observable starts *unconnected*: its internal `LazyObservable` has no upstream source until the factor graph is activated. During activation, `ReactiveMP.connect!` wires the lazy stream to a computed source (e.g. `collectLatest` over inbound messages for a [`ReactiveMP.RandomVariable`](@ref), or the observation channel for a [`ReactiveMP.DataVariable`](@ref)). After that point, every message update propagates through the graph and the `MarginalObservable` emits a fresh `Marginal`. + +The internal `RecentSubject` ensures that: +- any subscriber that joins after the first emission immediately receives the current belief via `Rocket.getrecent` +- [`ReactiveMP.set_initial_marginal!`](@ref) can seed an initial value *before* activation, so that rules which depend on a marginal at iteration zero have something to read + +All downstream subscriptions go through the `LazyObservable`, not the subject directly, so they see the full computed stream rather than only manually pushed values. diff --git a/docs/src/lib/message.md b/docs/src/lib/message.md index 6098c4854..1174c4c42 100644 --- a/docs/src/lib/message.md +++ b/docs/src/lib/message.md @@ -41,13 +41,13 @@ Message From an implementation point a view the `Message` structure does nothing but hold some `data` object and redirects most of the statistical related functions to that `data` object. However, this object is used extensively in Julia's multiple dispatch. -Our implementation also uses extra `is_initial` and `is_clamped` fields to determine if [product of two messages](@ref lib-messages-product) results in `is_initial` or `is_clamped` posterior marginal. The final field contains the addons. These contain additional information on top of the functional form of the distribution, such as its scaling or computation history. +Our implementation also uses extra `is_initial` and `is_clamped` fields to determine if [product of two messages](@ref lib-messages-product) results in `is_initial` or `is_clamped` posterior marginal. Each message also carries an [`AnnotationDict`](@ref) for optional metadata such as log-scale factors or computation history (see [Annotations](@ref lib-annotations)). ```@docs ReactiveMP.getdata(message::Message) ReactiveMP.is_clamped(message::Message) ReactiveMP.is_initial(message::Message) -ReactiveMP.getaddons(message::Message) +ReactiveMP.getannotations(message::Message) ReactiveMP.as_message ``` @@ -55,7 +55,7 @@ ReactiveMP.as_message using ReactiveMP, BayesBase, ExponentialFamily distribution = ExponentialFamily.NormalMeanPrecision(0.0, 1.0) -message = Message(distribution, false, true, nothing) +message = Message(distribution, false, true) ``` ```@example message @@ -70,16 +70,34 @@ logpdf(message, 1.0) is_clamped(message), is_initial(message) ``` +## Message observable + +Within the reactive message passing framework, messages are not computed once and stored as values — instead each edge of the factor graph carries a *stream* that continuously emits updated messages as the inference iterates. `MessageObservable` is the container for such a stream. + +```@docs +ReactiveMP.MessageObservable +``` + +Each connection between a variable and a factor node owns one `MessageObservable`. From the variable's perspective it is an *inbound* message stream (a message arriving from a connected node); from the node's perspective the same object is the message that will eventually be used to compute the outbound message on another edge. The observable starts *unconnected*: its internal `LazyObservable` has no upstream source until the factor graph is activated. During activation, `ReactiveMP.connect!` wires the lazy stream to the result of the message update rule computation. After that point, every upstream change (a new observation, a changed prior, an iterated belief) propagates reactively through the `MessageObservable` to all its subscribers. + +The internal `RecentSubject` ensures that: +- any subscriber that joins after the first emission immediately receives the current message via `Rocket.getrecent` +- [`ReactiveMP.set_initial_message!`](@ref) can seed a value *before* activation, so that rules that read an inbound message at iteration zero have something to read + +All downstream subscriptions go through the `LazyObservable`, not the subject directly, so they see the full computed stream rather than only manually pushed values. + ### [Product of messages](@id lib-messages-product) In message passing framework, in order to compute a posterior we must compute a normalized product of two messages. -For this purpose the `ReactiveMP.jl` uses the `multiply_messages` function, which internally uses the `prod` function -defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the -`BayesBase.jl` package for more information. +For this purpose the `ReactiveMP.jl` uses the [`ReactiveMP.MessageProductContext`](@ref) structure, together with the [`ReactiveMP.compute_product_of_messages`](@ref) and [`ReactiveMP.compute_product_of_two_messages`](@ref) functions. Both functions accept a [`ReactiveMP.AbstractVariable`](@ref) as the first argument to identify which variable the product is being computed for — this is useful for callbacks (e.g. [`ReactiveMP.BeforeProductOfTwoMessagesEvent`](@ref)). The [`ReactiveMP.compute_product_of_two_messages`](@ref) function internally uses the `prod` function +defined in `BayesBase.jl` with various product strategies. We refer an interested reader to the documentation of the `BayesBase.jl` package for more information. ```@docs -ReactiveMP.multiply_messages -ReactiveMP.messages_prod_fn +ReactiveMP.MessageProductContext +ReactiveMP.compute_product_of_two_messages +ReactiveMP.compute_product_of_messages +ReactiveMP.MessagesProductFromLeftToRight +ReactiveMP.MessagesProductFromRightToLeft ``` @@ -95,4 +113,4 @@ A *message mapping* defines how messages are transformed or mapped during the pr ```@docs ReactiveMP.MessageMapping -``` \ No newline at end of file +``` diff --git a/docs/src/lib/nodes.md b/docs/src/lib/nodes.md index e1c97346c..44f90d149 100644 --- a/docs/src/lib/nodes.md +++ b/docs/src/lib/nodes.md @@ -7,22 +7,37 @@ A factor node represents a local function in a factorised representation of a ge @node ReactiveMP.FactorNode ReactiveMP.FactorNodeLocalMarginal +``` + +## [Interfaces](@id lib-node-interfaces) + +Every edge of a factor node — a connection to one variable — is represented by a [`ReactiveMP.NodeInterface`](@ref). When a `FactorNode` is constructed, one `NodeInterface` is created per edge. The constructor of `NodeInterface` immediately calls `ReactiveMP.create_new_stream_of_inbound_messages!` on the connected variable, which allocates a per-connection [`ReactiveMP.MessageObservable`](@ref) slot in the variable's `input_messages` and returns it. This observable is stored as `m_out` on the interface: it is the *outbound* message from the node's perspective (flowing toward the variable) and the *inbound* message from the variable's perspective. + +At construction time all message streams are unconnected (lazy). The actual rule computations are wired up later during graph activation (see [Activation](@ref lib-node-activation)). + +For nodes with a variable-length list of same-named edges (e.g. the `means` of a Gaussian Mixture node), [`ReactiveMP.IndexedNodeInterface`](@ref) wraps a `NodeInterface` and adds a positional index. The `ReactiveMP.ManyOf` container collects the corresponding streams for use in `@rule` dispatch; see the [Delta node](@ref lib-nodes-delta) documentation for usage examples. + +```@docs ReactiveMP.NodeInterface ReactiveMP.IndexedNodeInterface -ReactiveMP.messagein -ReactiveMP.messageout +ReactiveMP.get_stream_of_inbound_messages +ReactiveMP.get_stream_of_outbound_messages +ReactiveMP.set_stream_of_outbound_messages! ReactiveMP.tag ReactiveMP.name ReactiveMP.interfaces ReactiveMP.getvariable ReactiveMP.inputinterfaces ReactiveMP.alias_interface -ReactiveMP.collect_factorisation -ReactiveMP.collect_pipeline -ReactiveMP.collect_meta -ReactiveMP.default_meta -ReactiveMP.as_node_symbol -ReactiveMP.nodesymbol_to_nodefform +``` + +## [Activation](@id lib-node-activation) + +Graph activation is the step that connects all lazy [`ReactiveMP.MessageObservable`](@ref) and [`ReactiveMP.MarginalObservable`](@ref) streams into a live reactive network. For factor nodes this is done by calling [`ReactiveMP.activate!`](@ref) with a [`ReactiveMP.FactorNodeActivationOptions`](@ref) that bundles all inference-time configuration. + +```@docs +ReactiveMP.FactorNodeActivationOptions +ReactiveMP.activate!(::FactorNode, ::ReactiveMP.FactorNodeActivationOptions) ``` ## [Adding a custom node](@id lib-custom-node) @@ -44,6 +59,18 @@ This expression registers a new node that can be used with the inference engine. Note, however, that the `@node` macro does not generate any message passing update rules. These must be defined using the [`@rule`](@ref) macro. +## [Collecting node properties](@id lib-node-collect) + +```@docs +ReactiveMP.collect_factorisation +ReactiveMP.collect_meta +ReactiveMP.default_meta +ReactiveMP.as_node_symbol +ReactiveMP.nodesymbol_to_nodefform +ReactiveMP.FunctionalDependencies +ReactiveMP.collect_functional_dependencies +``` + ## [Node types](@id lib-node-types) We distinguish different types of factor nodes in order to have better control over Bethe Free Energy computation. @@ -85,10 +112,10 @@ println("sdtype() of `Bernoulli` node is ", sdtype(Bernoulli)) nothing #hide ``` -## [Node functional dependencies pipeline](@id lib-node-functional-dependencies-pipeline) +## [Node functional dependencies](@id lib-node-functional-dependencies) -The generic implementation of factor nodes in ReactiveMP supports custom functional dependency pipelines. Briefly, the __functional dependencies pipeline__ defines what -dependencies are need to compute a single message. As an example, consider the belief-propagation message update equation for a factor node $f$ with three edges: $x$, $y$ and $z$: +The generic implementation of factor nodes in ReactiveMP supports custom functional dependencies policies. Briefly, the __functional dependencies__ define what +dependencies are needed to compute a single message. As an example, consider the belief-propagation message update equation for a factor node $f$ with three edges: $x$, $y$ and $z$: ```math \mu(x) = \int \mu(y) \mu(z) f(x, y, z) \mathrm{d}y \mathrm{d}z @@ -102,9 +129,9 @@ Here we see that in the standard setting for the belief-propagation message out We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$. -## [List of functional dependencies pipelines](@id lib-node-functional-dependencies-pipelines) +## [List of functional dependencies policies](@id lib-node-functional-dependencies-policies) -The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below: +The purpose of a __functional dependencies__ policy is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in policies is presented below: ```@docs ReactiveMP.DefaultFunctionalDependencies @@ -113,27 +140,6 @@ ReactiveMP.RequireMarginalFunctionalDependencies ReactiveMP.RequireEverythingFunctionalDependencies ``` -## [Customizing Dependencies with Metadata](@id lib-node-metadata-dependencies) - -The functional dependencies of a node can be customized at runtime using options during node activation. This allows for runtime customization of the functional dependencies, e.g. to test different message passing schemes or implement specialized behavior for specific instances of a node type: - -```julia -# Define custom dependencies based on metadata -function ReactiveMP.collect_functional_dependencies(::Type{MyNode}, options::FactorNodeActivationOptions) - if some_condition(options) # a user can specify dependencies based, for example, on metadata - return CustomDependencies() - end - # Fall back to default dependencies - return ReactiveMP.collect_functional_dependencies(MyNode, getdependecies(options)) -end - -# Use custom dependencies during activation -node = factornode(MyNode, ...) -activate!(node, FactorNodeActivationOptions(:custom_behavior, ...)) -``` - -This feature is particularly useful for testing different message passing schemes or implementing specialized behavior for specific instances of a node type. - ## [Node traits](@id lib-node-traits) Each factor node has to define the [`ReactiveMP.is_predefined_node`](@ref) trait function and to specify a [`ReactiveMP.PredefinedNodeFunctionalForm`](@ref) @@ -149,19 +155,11 @@ ReactiveMP.UndefinedNodeFunctionalForm ReactiveMP.is_predefined_node ``` -## [Node pipelines](@id lib-node-pipelines) +## [Stream postprocessors](@id lib-node-stream-postprocessors) -```@docs -ReactiveMP.AbstractPipelineStage -ReactiveMP.apply_pipeline_stage -ReactiveMP.EmptyPipelineStage -ReactiveMP.CompositePipelineStage -ReactiveMP.LoggerPipelineStage -ReactiveMP.DiscontinuePipelineStage -ReactiveMP.AsyncPipelineStage -ReactiveMP.ScheduleOnPipelineStage -ReactiveMP.schedule_updates -``` +Stream postprocessors are composable transformations applied to the reactive observables produced during activation — outbound message streams, marginal streams, and score streams. They are attached to a node via [`ReactiveMP.FactorNodeActivationOptions`](@ref) and to a random variable via [`ReactiveMP.RandomVariableActivationOptions`](@ref), and can be used for scheduling or custom instrumentation. + +See the dedicated [Stream postprocessors](@ref lib-stream-postprocessors) page for a full description and API reference. ## [List of predefined factor node](@id lib-predefined-nodes) diff --git a/docs/src/lib/nodes/ar.md b/docs/src/lib/nodes/ar.md index f840c85cc..af6d7da38 100644 --- a/docs/src/lib/nodes/ar.md +++ b/docs/src/lib/nodes/ar.md @@ -1,5 +1,52 @@ # [Autoregressive node](@id lib-nodes-ar) +The `AR` node (also exported as `Autoregressive`) encodes a **Bayesian autoregressive process** of order `p`: + +```math +y_t \sim \mathcal{N}(\theta^\top x_t, \, \gamma^{-1}) +``` + +where `yₜ` is the current observation, `xₜ = (yₜ₋₁, …, yₜ₋ₚ)` is the vector of `p` lagged values, `θ` is the vector of AR coefficients, and `γ` is the observation precision. + +This node is the natural building block for **time series models** such as AR(p), latent AR processes, and state-space models with autoregressive dynamics. + +## [Interfaces](@id lib-nodes-ar-interfaces) + +| Interface | Alias | Role | +|-----------|-------|------| +| `y` | `out` | Current observation `yₜ` | +| `x` | — | Lagged state vector `(yₜ₋₁, …, yₜ₋ₚ)` | +| `θ` | — | AR coefficient vector (length `p`) | +| `γ` | — | Observation precision (scalar) | + +## [Metadata](@id lib-nodes-ar-meta) + +`ARMeta` is required and must be passed explicitly — the node has no default meta: + +```julia +y[t] ~ AR(x[t], θ, γ) where { meta = ARMeta(Multivariate, order, ARsafe()) } +``` + +The constructor takes: +- `Univariate` or `Multivariate` — variate form (determines how `x` and `y` are interpreted). +- `order` — the AR order `p` (must equal 1 for `Univariate`). +- `ARsafe()` or `ARunsafe()` — numerical stability mode (`ARsafe` adds a small regularization to avoid singular matrices; `ARunsafe` is faster but may be numerically fragile). + +## [Univariate vs multivariate](@id lib-nodes-ar-variate) + +`ARMeta{Univariate}` treats `y` and the first element of `x` as scalars, with order forced to 1. This is an AR(1) model. + +`ARMeta{Multivariate}` uses the full companion-matrix representation to handle AR(p) for `p > 1`. The state vector `x` has length `p`, and the AR process is embedded as a linear state-space model. See [`CompanionMatrix`](@ref) for the underlying algebraic structure. + +## [State vector slicing](@id lib-nodes-ar-slicing) + +The [`ReactiveMP.ar_unit`](@ref) and [`ReactiveMP.ar_slice`](@ref) utilities extract specific parts of the joint state vector in the multivariate setting: + +- [`ReactiveMP.ar_unit`](@ref) — returns an appropriately shaped zero vector or matrix for initializing accumulators. +- [`ReactiveMP.ar_slice`](@ref) — extracts a subvector or submatrix from a joint mean/covariance. This is used inside rules to separate the `y` part from the `x` part of the joint Gaussian `q(y, x)`. + +These are internal helpers that surface when writing custom rules for AR-based models. + ```@docs ReactiveMP.ar_unit ReactiveMP.ar_slice diff --git a/docs/src/lib/nodes/bifm.md b/docs/src/lib/nodes/bifm.md index d0be43772..d70e1030d 100644 --- a/docs/src/lib/nodes/bifm.md +++ b/docs/src/lib/nodes/bifm.md @@ -1,9 +1,55 @@ # [BIFM node](@id lib-nodes-bifm) -See also [BIFM tutorial](https://reactivebayes.github.io/RxInfer.jl/stable/examples/overview/) for a comprehensive guide on using BIFM node in `RxInfer.jl`. +The **Backward Information Forward Marginals (BIFM)** node implements an efficient Kalman smoothing step for linear Gaussian state-space models. It fuses all factor contributions within a single time slice — observation likelihood, state transition, and the backward information from future time steps — into one node, enabling correct smoothed marginals without a separate backward pass. + +## [Model structure](@id lib-nodes-bifm-model) + +The BIFM node has four interfaces: + +| Interface | Role | +|-----------|------| +| `out` | Latent output (observation) of the time slice | +| `in` | Latent input to the time slice (e.g., a control signal) | +| `zprev` | Previous latent state `zₜ₋₁` | +| `znext` | Next latent state `zₜ` (carries backward information from future) | + +The state-space equations encoded by the node are: + +```math +z_t = A \, z_{t-1} + B \, u_t, \qquad x_t = C \, z_t +``` + +where `A`, `B`, and `C` are the transition, input, and output matrices stored in [`BIFMMeta`](@ref). + +## [Usage](@id lib-nodes-bifm-usage) + +The BIFM node must be used together with [`BIFMHelper`](@ref), which carries backward smoothing information between time steps. A typical model looks like: + +```julia +z_prior ~ MvNormalMeanPrecision(zeros(latent_dim), diagm(ones(latent_dim))) +z_tmp ~ BIFMHelper(z_prior) +z_prev = z_tmp + +for i in 1:nr_samples + u[i] ~ MvNormalMeanPrecision(μu, Wu) + xt[i] ~ BIFM(u[i], z_prev, z[i]) where { meta = BIFMMeta(A, B, C) } + x[i] ~ MvNormalMeanPrecision(xt[i], Wx) + z_prev = z[i] +end +``` + +!!! note + When subscribing to marginals, subscribe in the order `z`, `out`, `in` before subscribing to the free energy score function. This ordering ensures that the backward information is propagated correctly before the score is evaluated. + +## [Relationship to ContinuousTransition](@id lib-nodes-bifm-vs-ctransition) + +The [`ContinuousTransition`](@ref) node encodes a single linear-Gaussian transition `y ~ N(K(a)·x, W⁻¹)` where the transition matrix can itself be a latent variable. BIFM is a more specialized node: the matrices `A`, `B`, `C` are fixed (passed through meta), but the node efficiently handles the full time-slice factor, including the smoothing backward pass. Use `ContinuousTransition` when the transition matrix is uncertain and must be inferred; use BIFM when the structure is known and smoothing efficiency matters. + +!!! note + See also the [BIFM tutorial](https://reactivebayes.github.io/RxInfer.jl/stable/examples/overview/) in the RxInfer.jl documentation for a comprehensive guide. ```@docs ReactiveMP.BIFM ReactiveMP.BIFMMeta ReactiveMP.BIFMHelper -``` \ No newline at end of file +``` diff --git a/docs/src/lib/nodes/binomial_polya.md b/docs/src/lib/nodes/binomial_polya.md index c92407f36..a0e04002b 100644 --- a/docs/src/lib/nodes/binomial_polya.md +++ b/docs/src/lib/nodes/binomial_polya.md @@ -1,6 +1,40 @@ # [BinomialPolya node](@id lib-nodes-binomial-polya) -The BinomialPolya node implements a Binomial likelihood with logistic linear predictor and PolyaGamma augmentation for Bayesian inference. This node is particularly useful for modeling count data with overdispersion and performing Binomial regression. +The `BinomialPolya` node implements a **Binomial likelihood with a logistic linear predictor**, augmented with a Pólya-Gamma auxiliary variable for tractable Bayesian inference: + +```math +y \mid x, \beta, n \sim \mathrm{Binomial}\!\left(n,\; \sigma(x^\top \beta)\right) +``` + +where `σ` is the logistic (sigmoid) function, `x` is a feature vector, `β` is a weight vector with a Normal prior, and `n` is the number of trials. + +## [Interfaces](@id lib-nodes-binomial-polya-interfaces) + +| Interface | Role | +|-----------|------| +| `y` | Observed count (number of successes) | +| `x` | Feature vector | +| `n` | Number of trials | +| `β` | Weight vector (Normal prior) | + +## [The Pólya-Gamma augmentation trick](@id lib-nodes-binomial-polya-augmentation) + +Combining a Normal prior on `β` with a Binomial likelihood through a logistic link is not conjugate — the posterior has no closed form. The **Pólya-Gamma augmentation** (Polson et al., 2013) introduces a latent variable `ω ~ PG(n, x⊤β)` such that, conditional on `ω`, the likelihood becomes Gaussian. This makes the full-conditional update for `β` analytically tractable and allows the engine to perform exact conjugate message passing instead of sampling or variational approximation. + +This is useful for: +- **Binomial regression** — modeling count data with a logistic link. +- **Binary classification** — as a special case with `n = 1`. + +## [Meta and tuning](@id lib-nodes-binomial-polya-meta) + +`BinomialPolyaMeta` controls the Monte Carlo estimation of the average energy: + +| Field | Default | Effect | +|-------|---------|--------| +| `n_samples` | `1` | Number of samples for MC energy estimation. Increasing adds cost with diminishing accuracy benefit. | +| `rng` | `Random.default_rng()` | Random number generator. | + +If no meta is provided (`meta = nothing`), the rules use posterior means instead of sampling, which yields very similar results at no extra cost. ```@docs ReactiveMP.BinomialPolya diff --git a/docs/src/lib/nodes/ctransition.md b/docs/src/lib/nodes/ctransition.md index d8bd2843b..381138576 100644 --- a/docs/src/lib/nodes/ctransition.md +++ b/docs/src/lib/nodes/ctransition.md @@ -1,4 +1,66 @@ -# [Continous transition node](@id lib-nodes-ctransition) +# [Continuous transition node](@id lib-nodes-ctransition) + +The `ContinuousTransition` node encodes a **linear (or nonlinear) Gaussian state transition**: + +```math +y \sim \mathcal{N}(K(a) \cdot x, \, W^{-1}) +``` + +It transforms an `m`-dimensional input vector `x` into an `n`-dimensional output vector `y` via a learned matrix `K(a)`, where `a` is a latent vector and `K` is a user-supplied transformation function. The precision matrix `W` controls the amount of jitter in the transition. + +This node is the continuous-state counterpart of `DiscreteTransition` and the primary building block for **Kalman-filter-style state-space models** where the transition matrix is uncertain and must be inferred. + +## [Interfaces](@id lib-nodes-ctransition-interfaces) + +| Interface | Role | +|-----------|------| +| `y` | `n`-dimensional output state | +| `x` | `m`-dimensional input state | +| `a` | Vector parameterizing the transition matrix via `K(a)` | +| `W` | `n×n` precision matrix of the transition noise | + +## [Specifying the transformation](@id lib-nodes-ctransition-transformation) + +The transformation `K(a)` is passed through [`ContinuousTransitionMeta`](@ref) (alias: `CTMeta`). The function must return an `n×m` matrix. For example: + +```julia +# Unstructured: reshape a length-4 vector into a 2×2 matrix +transformation = a -> reshape(a, 2, 2) + +a ~ MvNormalMeanCovariance(zeros(4), Diagonal(ones(4))) +y ~ ContinuousTransition(x, a, W) where { meta = CTMeta(transformation) } +``` + +When the matrix has known structure, `K(a)` can encode it explicitly: + +```julia +# Rotation matrix parameterized by a single angle +transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + +a ~ MvNormalMeanCovariance([0.0], [1.0;;]) +y ~ ContinuousTransition(x, a, W) where { meta = CTMeta(transformation) } +``` + +!!! note + Even for scalar transitions, `a` must be a vector (length 1). Use `MvNormal` rather than `Normal` for the prior on `a`. + +## [Factorization constraints](@id lib-nodes-ctransition-factorization) + +The node supports two factorization assumptions: + +**Mean-field** — all variables are treated as independent: +```julia +q(y, x, a, W) = q(y)q(x)q(a)q(W) +``` + +**Structured** — the joint `q(y, x)` is kept intact (useful for Kalman smoothing): +```julia +q(y, x, a, W) = q(y, x)q(a)q(W) +``` + +## [Companion matrix](@id lib-nodes-ctransition-companion) + +For autoregressive-style transitions, the companion matrix representation converts an AR coefficient vector into a state transition matrix. See [`CompanionMatrix`](@ref) in the algebra utilities and the [Autoregressive node](@ref lib-nodes-ar) for a specific application. ```@docs ReactiveMP.ContinuousTransition diff --git a/docs/src/lib/nodes/delta.md b/docs/src/lib/nodes/delta.md index 6b7bbc988..b8ac0c2e6 100644 --- a/docs/src/lib/nodes/delta.md +++ b/docs/src/lib/nodes/delta.md @@ -1,29 +1,66 @@ # [Delta node](@id lib-nodes-delta) +The delta node encodes a **deterministic functional relationship** between variables. Where a stochastic node represents `p(y | x)`, a delta node asserts that `y = f(x₁, …, xₙ)` exactly. Any Julia function `f` can be used. + +```julia +z ~ f(x, y) # z is deterministically f(x, y) +``` + +Because `f` is not a probability distribution, the standard closed-form message computation does not apply. The engine must **approximate** the outgoing messages. The approximation method is specified via [`DeltaMeta`](@ref): + +```julia +z ~ f(x, y) where { meta = DeltaMeta(method = Linearization()) } +z ~ f(x, y) where { meta = DeltaMeta(method = Unscented()) } +z ~ f(x, y) where { meta = DeltaMeta(method = CVI(...)) } +``` + +## [Choosing an approximation method](@id lib-nodes-delta-methods) + +| Method | Best for | What it needs | +|--------|----------|---------------| +| [`Linearization`](@ref) | Smooth `f` that is approximately linear near the operating point | Jacobian, computed via ForwardDiff automatically | +| [`Unscented`](@ref) / [`UT`](@ref) | Nonlinear but smooth `f` in moderate dimension | Sigma points; no derivatives required | +| [`CVI`](@ref) | Black-box or non-differentiable `f`, high dimension | Stochastic gradient estimator; requires an optimizer | +| [`CVIProjection`](@ref) | Same as `CVI` with the result projected onto an exponential family member | Same as `CVI` | +| `LaplaceApproximation` | Unimodal posteriors; `f` differentiable | Second-order Taylor expansion at the mode | + +When `f` has a known analytical inverse `f⁻¹`, you can pass it as the `inverse` keyword to skip the backward approximation entirely: + +```julia +z ~ f(x) where { meta = DeltaMeta(method = Linearization(), inverse = f_inv) } +``` + +Without an inverse, the backward (input) messages are computed via the [RTS smoother](@ref ReactiveMP.smoothRTS) (Petersen et al., 2018). + +## [Multi-input delta nodes](@id lib-nodes-delta-manyof) + +When a delta node has more than one input, the `@rule` macro receives the inputs bundled in a [`ReactiveMP.ManyOf`](@ref) container. This lets the rule dispatch on the collection of input messages rather than individually: + +```julia +@rule DeltaFn{typeof(f)}(:out, Marginalisation) ( + m_ins::ReactiveMP.ManyOf, + meta::DeltaMeta{<:Linearization}, +) = begin + # m_ins[1], m_ins[2], ... are the individual input messages + ... +end +``` + +See the [Message update rules](@ref lib-rules) page for how to define rules with `@rule`. + +!!! note + The delta node is [`Deterministic`](@ref) and does not contribute to the Bethe free energy directly. It only transforms information between variables. + +For the full API of approximation methods (CVI, Unscented, Linearization, etc.), see [Approximation methods](@ref lib-approximations). + ```@docs ReactiveMP.DeltaMeta ReactiveMP.ManyOf -ReactiveMP.Linearization -ReactiveMP.local_linearization ReactiveMP.smoothRTS -ReactiveMP.Unscented -ReactiveMP.sigma_points_weights ReactiveMP.CVIApproximationDeltaFnRuleLayout ReactiveMP.log_approximate -ReactiveMP.ForwardDiffGrad -ReactiveMP.UT -ReactiveMP.UnscentedTransform -ReactiveMP.ProdCVI -ReactiveMP.CVI -ReactiveMP.CVIProjection -ReactiveMP.CVISamplingStrategy -ReactiveMP.FullSampling -ReactiveMP.MeanBased -ReactiveMP.ProposalDistributionContainer -ReactiveMP.cvi_setup! -ReactiveMP.cvi_update! ReactiveMP.DeltaFnDefaultRuleLayout ReactiveMP.DeltaFnDefaultKnownInverseRuleLayout -ReactiveMP.SoftDot -ReactiveMP.softdot -``` \ No newline at end of file +SoftDot +softdot +``` diff --git a/docs/src/lib/nodes/discrete_transition.md b/docs/src/lib/nodes/discrete_transition.md index 97e2cae6a..6d123d7fe 100644 --- a/docs/src/lib/nodes/discrete_transition.md +++ b/docs/src/lib/nodes/discrete_transition.md @@ -1,9 +1,63 @@ +# [Discrete transition node](@id lib-nodes-discrete-transition) + +The `DiscreteTransition` node encodes a **Markov state transition** for discrete categorical variables. It represents the conditional distribution: + +```math +p(\text{out} \mid \text{in}, A) = A \cdot \text{in} +``` + +where `out` and `in` are categorical (discrete) state variables and `A` is a column-stochastic transition matrix (each column sums to one). This is the fundamental building block for **Hidden Markov Models (HMMs)** and other discrete state-space models. + +## [Interfaces](@id lib-nodes-discrete-transition-interfaces) + +The `DiscreteTransition` node accepts a variable number of inputs: + +| Interface index | Alias | Role | +|----------------|-------|------| +| 1 | `out` | Next (output) state | +| 2 | `in` | Current (input) state | +| 3 | `a` | Transition matrix variable | +| 4, 5, … | `T1`, `T2`, … | Optional additional transition matrices | + +This flexible interface allows multi-dimensional transition structures where the full transition is a product of several matrices. + +## [Comparison with a plain Categorical node](@id lib-nodes-discrete-transition-vs-categorical) + +A plain `Categorical` node fixes the probability vector at the time the node is created. `DiscreteTransition` is different in two important ways: + +1. **The transition matrix `a` is a variable** — it can have a `DirichletCollection` prior and its posterior is inferred jointly with the states. +2. **The input state is also a variable** — messages flow in both directions, making it possible to infer both past states (smoothing) and future states (prediction). + +## [Typical usage pattern](@id lib-nodes-discrete-transition-usage) + +```julia +# prior on initial state +s[1] ~ Categorical(fill(1/K, K)) + +# prior on transition matrix (one Dirichlet per column) +A ~ DirichletCollection(ones(K, K)) + +# Markov chain +for t in 2:T + s[t] ~ DiscreteTransition(s[t-1], A) +end + +# emission likelihoods +for t in 1:T + y[t] ~ Categorical(B * s[t]) +end +``` + +## [Utility functions](@id lib-nodes-discrete-transition-utils) + +The following internal functions implement the message update rules for the `DiscreteTransition` node. They are exposed for users who want to reuse them in custom rule definitions. + ```@docs -ReactiveMP.discrete_transition_decode_marginal +ReactiveMP.discrete_transition_decode_marginal ReactiveMP.discrete_transition_marginal_rule ReactiveMP.discrete_transition_process_marginals ReactiveMP.multiply_dimensions! ReactiveMP.sum_out_dimensions ReactiveMP.discrete_transition_process_messages ReactiveMP.discrete_transition_structured_message_rule -``` \ No newline at end of file +``` diff --git a/docs/src/lib/nodes/flow.md b/docs/src/lib/nodes/flow.md index 396b03e42..9fb61b698 100644 --- a/docs/src/lib/nodes/flow.md +++ b/docs/src/lib/nodes/flow.md @@ -1,6 +1,59 @@ # [Flow node](@id lib-nodes-flow) -See also [Flow tutorial](https://reactivebayes.github.io/RxInfer.jl/stable/examples/overview/) for a comprehensive guide on using flows in `RxInfer.jl`. +The flow node encodes a **normalizing flow** — a parameterized, invertible transformation that maps a simple base distribution (e.g., a Gaussian) into a complex, multimodal one. Because the transformation is invertible, both forward and backward messages can be computed without approximation. + +```julia +y ~ Flow(x) where { meta = FlowMeta(compiled_model) } +``` + +This asserts that `y = f(x)` where `f` is the composed invertible transformation defined by the flow model. The node type is [`Deterministic`](@ref). + +## [Building a flow model](@id lib-nodes-flow-model) + +A flow is assembled from **layers** stacked inside a [`FlowModel`](@ref). Each layer is an invertible mapping. ReactiveMP.jl provides: + +| Layer | Description | +|-------|-------------| +| [`PlanarFlow`](@ref) | Planar contraction/expansion along a learned direction | +| [`RadialFlow`](@ref) | Radial contraction/expansion around a learned center point | +| [`AdditiveCouplingLayer`](@ref) | Affine coupling — splits the input and transforms one half conditioned on the other | +| [`PermutationLayer`](@ref) | Permutes the input dimensions to mix information across coupling layers | +| [`InputLayer`](@ref) | Declares the input dimensionality; must be the first layer in a model | + +Layers are composed into a `FlowModel` and then **compiled** into a [`CompiledFlowModel`](@ref) before use. Compilation fixes the layer sizes and randomly initializes parameters: + +```julia +model = FlowModel(( + InputLayer(2), + AdditiveCouplingLayer(PlanarFlow()), + PermutationLayer(), + AdditiveCouplingLayer(PlanarFlow()), +)) + +compiled = compile(model) # randomly initialized parameters + +# or pass your own parameter vector: +compiled = compile(model, params) +``` + +The compiled model is then wrapped in [`FlowMeta`](@ref) and attached to the node: + +```julia +y ~ Flow(x) where { meta = FlowMeta(compiled) } +``` + +## [Approximation inside the flow node](@id lib-nodes-flow-approximation) + +By default, [`FlowMeta`](@ref) uses [`Linearization`](@ref) for any messages that require approximation (e.g., backward messages when the inverse Jacobian is unavailable in closed form). A different approximation can be passed as the second argument: + +```julia +FlowMeta(compiled, Unscented()) +``` + +## [Learning flow parameters](@id lib-nodes-flow-learning) + +!!! note + See also the [Flow tutorial](https://reactivebayes.github.io/RxInfer.jl/stable/examples/overview/) in the RxInfer.jl documentation for a complete end-to-end example. ```@docs PlanarFlow diff --git a/docs/src/lib/nodes/logical.md b/docs/src/lib/nodes/logical.md index cc38c5598..38f3ac3c0 100644 --- a/docs/src/lib/nodes/logical.md +++ b/docs/src/lib/nodes/logical.md @@ -1,4 +1,63 @@ -# [Logical operations](@id lib-nodes-logical) +# [Logical operation nodes](@id lib-nodes-logical) + +Logical nodes encode **hard Boolean constraints** between discrete binary variables. Each node represents a standard logic gate and enforces the corresponding truth table exactly — no approximation is needed for discrete variables. + +These nodes are [`Deterministic`](@ref): they do not contribute probability mass directly, but constrain the joint distribution by making the outcome of the logic gate a deterministic function of its inputs. + +## [Available operations](@id lib-nodes-logical-operations) + +| Node | Operation | Output | +|------|-----------|--------| +| [`AND`](@ref) | Logical conjunction | `out = in1 ∧ in2` | +| [`OR`](@ref) | Logical disjunction | `out = in1 ∨ in2` | +| [`IMPLY`](@ref) | Logical implication | `out = in1 ⇒ in2` | +| [`NOT`](@ref) | Logical negation | `out = ¬in` | + +All inputs and outputs are binary (Bernoulli-distributed) variables. The truth tables are: + +**AND** + +| `in1` | `in2` | `out` | +|-------|-------|-------| +| 0 | 0 | 0 | +| 0 | 1 | 0 | +| 1 | 0 | 0 | +| 1 | 1 | 1 | + +**OR** + +| `in1` | `in2` | `out` | +|-------|-------|-------| +| 0 | 0 | 0 | +| 0 | 1 | 1 | +| 1 | 0 | 1 | +| 1 | 1 | 1 | + +**IMPLY** (`in1 ⇒ in2`, equivalent to `¬in1 ∨ in2`) + +| `in1` | `in2` | `out` | +|-------|-------|-------| +| 0 | 0 | 1 | +| 0 | 1 | 1 | +| 1 | 0 | 0 | +| 1 | 1 | 1 | + +**NOT** + +| `in` | `out` | +|------|-------| +| 0 | 1 | +| 1 | 0 | + +## [When to use logical nodes](@id lib-nodes-logical-when) + +Logical nodes are useful whenever your model contains **prior structural knowledge** that relates binary events. Common use cases include: + +- Encoding that "event A occurring implies event B also occurs": `b ~ IMPLY(a, b_evidence)`. +- Building fault-tree or diagnostic models where system failures are logical combinations of component failures. +- Expressing hard constraints in discrete Bayesian networks. + +Because the constraints are exact, the resulting messages are also exact for binary inputs — no Monte Carlo or variational approximation is required. ```@docs ReactiveMP.AND diff --git a/docs/src/lib/nodes/multinomial_polya.md b/docs/src/lib/nodes/multinomial_polya.md index 4c43d295f..e8ae1a95b 100644 --- a/docs/src/lib/nodes/multinomial_polya.md +++ b/docs/src/lib/nodes/multinomial_polya.md @@ -1,6 +1,40 @@ -# [Multinomial Polya node](@id lib-nodes-multinomial-polya) +# [MultinomialPolya node](@id lib-nodes-multinomial-polya) -The MultinomialPolya node implements a Multinomial likelihood with PolyaGamma augmentation for Bayesian inference. This node is particularly useful for modeling count data with overdispersion and performing Multinomial regression. +The `MultinomialPolya` node implements a **Multinomial likelihood with a softmax linear predictor**, augmented with Pólya-Gamma auxiliary variables for tractable Bayesian inference: + +```math +x \mid N, \psi \sim \mathrm{Multinomial}\!\left(N,\; \mathrm{softmax}(\psi)\right) +``` + +where `x` is a count vector, `N` is the total number of trials, and `ψ` is a latent vector with a Normal prior. + +## [Interfaces](@id lib-nodes-multinomial-polya-interfaces) + +| Interface | Role | +|-----------|------| +| `x` | Observed count vector | +| `N` | Total number of trials | +| `ψ` | Softmax weight vector (Normal prior) | + +## [The Pólya-Gamma augmentation trick](@id lib-nodes-multinomial-polya-augmentation) + +A Normal prior on `ψ` combined with a Multinomial likelihood through a softmax link is not conjugate. The **Pólya-Gamma augmentation** (Polson et al., 2013) uses a set of latent Pólya-Gamma variables — one per category — such that, conditional on these variables, the likelihood factorizes into a product of Gaussian terms. This restores conjugacy with the Normal prior on `ψ` and allows closed-form VMP updates. + +The key advantage over Monte Carlo methods is that inference remains deterministic and converges smoothly, making it suitable for models where `ψ` is a latent variable that must be marginalized. + +Typical use cases: +- **Multinomial regression** — predicting category counts from a feature vector. +- **Topic models** — where category probabilities are the softmax of a Gaussian-distributed topic vector. + +## [Meta and tuning](@id lib-nodes-multinomial-polya-meta) + +`MultinomialPolyaMeta` controls the number of cubature points used to integrate out the Pólya-Gamma variables: + +| Field | Default | Effect | +|-------|---------|--------| +| `ncubaturepoints` | `21` ([`MULTINOMIAL_POLYA_CUBATURE_POINTS`](@ref ReactiveMP.MULTINOMIAL_POLYA_CUBATURE_POINTS)) | More points → higher accuracy, higher cost. Reduce to `7` or `9` for faster but less accurate inference. | + +The default of 21 cubature points balances accuracy and speed for typical problem sizes. ```@docs ReactiveMP.MultinomialPolya diff --git a/docs/src/lib/score.md b/docs/src/lib/score.md new file mode 100644 index 000000000..f59b2affb --- /dev/null +++ b/docs/src/lib/score.md @@ -0,0 +1,57 @@ +# [Score functions](@id lib-score) + +ReactiveMP.jl computes the **Bethe free energy** as its variational objective during inference. The free energy decomposes into local contributions from each factor node and each variable node, which are accumulated reactively as messages update. + +## [The Bethe free energy](@id lib-score-bethe) + +The Bethe free energy approximates the negative log-evidence of the model: + +```math +\mathcal{F}_{\text{Bethe}}[q] = \underbrace{\sum_f \langle -\log f \rangle_{q_f}}_{\text{average energy}} - \underbrace{\sum_f H[q_f]}_{\text{factor entropies}} + \underbrace{\sum_x (d_x - 1)\, H[q_x]}_{\text{variable entropies}} +``` + +where: +- The sum over `f` runs over all factor nodes, with `q_f` the local marginal over the factor's variables. +- The sum over `x` runs over all variable nodes, with `d_x` the degree (number of connected factors) and `q_x` the marginal of that variable. + +ReactiveMP.jl computes each term reactively: whenever a marginal changes, the local contribution is recomputed and can be accumulated by subscribing to the score streams. + +## [Score types](@id lib-score-types) + +Three tag types are used to dispatch the `score` function: + +| Type | Represents | Where used | +|------|-----------|-----------| +| `AverageEnergy` | `⟨-log f⟩_q` — the expected log-factor under the local marginal | Factor nodes | +| `DifferentialEntropy` | `-∫ q log q` — the Shannon entropy of a marginal | Factor and variable nodes | +| `FactorBoundFreeEnergy` | Local free energy contribution of one factor node | Factor nodes | +| `VariableBoundEntropy` | Scaled entropy contribution of one variable node | Variable nodes | + +The full Bethe free energy is the sum of all `FactorBoundFreeEnergy` and `VariableBoundEntropy` scores across the graph. + +## [The `score` function](@id lib-score-function) + +`score` is the central dispatch point. It is called internally by the engine, but can also be called manually for inspection: + +```julia +# Entropy of a marginal +score(DifferentialEntropy(), marginal) + +# Average energy for a factor node +score(AverageEnergy(), MyNode, Val{(:x, :y)}(), (q_x, q_y), meta) +``` + +## [Defining average energy for custom nodes](@id lib-score-average-energy) + +When adding a new factor node, the engine needs to know how to compute `⟨-log f⟩_q`. The `@average_energy` macro generates the required `score(::AverageEnergy, ...)` method: + +```julia +@average_energy MyNode (q_x::NormalMeanVariance, q_y::Gamma) begin + # return the average energy -⟨log f(x, y)⟩_{q(x)q(y)} + mx, vx = mean_var(q_x) + my = mean(q_y) + return 0.5 * log(2π) + 0.5 * (vx + mx^2) * my - ... +end +``` + +The macro handles argument naming, dispatch, and interface checking automatically. Marginals are named with a `q_` prefix matching the node interface names declared in the corresponding `@node` definition. diff --git a/docs/src/lib/stream-postprocessors.md b/docs/src/lib/stream-postprocessors.md new file mode 100644 index 000000000..9832476f2 --- /dev/null +++ b/docs/src/lib/stream-postprocessors.md @@ -0,0 +1,102 @@ +# [Stream postprocessors](@id lib-stream-postprocessors) + +A **stream postprocessor** is a composable transformation applied to one of the reactive observables produced during graph [Activation](@ref lib-node-activation). It wraps a Rocket.jl observable and returns a new observable of the same element type, leaving the message passing logic itself untouched. + +The same postprocessor can be applied to three different kinds of streams produced by the inference engine: + +- streams of **outbound messages** leaving a factor node interface or a leg of an [`ReactiveMP.EqualityChain`](@ref); +- streams of **marginals** emitted by a [`RandomVariable`](@ref) or by the local cluster of a factor node; +- streams of **scores** (free-energy contributions) used to assemble Bethe Free Energy. + +Stream postprocessors are useful for: + +- **Scheduling** — controlling *when* downstream subscribers observe updates (e.g. batching a wave of inbound observations into a single propagation step using a `PendingScheduler`, or moving work onto a worker thread using an `AsyncScheduler`). +- **Custom instrumentation** — applying any Rocket.jl operator (filtering, sampling, side-effects) on top of every stream produced by activation. + +!!! note + The previous `AbstractPipelineStage` API and the per-node `scheduler` argument have been unified into [`ReactiveMP.AbstractStreamPostprocessor`](@ref). The old `LoggerPipelineStage` is gone — equivalent behaviour can now be achieved through [callbacks](@ref lib-callbacks) without subscribing to the streams themselves. The migration guide also covers this change. + +## [Available stream postprocessors](@id lib-stream-postprocessors-available) + +| Postprocessor | Purpose | +|---------------|---------| +| `nothing` | No-op; the implicit default when no postprocessor is attached. The three `postprocess_stream_of_*` methods all have a `::Nothing` fallback that returns the stream unchanged. | +| [`ReactiveMP.ScheduleOnStreamPostprocessor`](@ref) | Redirects every emission to a Rocket.jl scheduler via the `schedule_on(scheduler)` operator. | +| [`ReactiveMP.CompositeStreamPostprocessor`](@ref) | Applies a sequence of postprocessors in order. | + +## [Composing stream postprocessors](@id lib-stream-postprocessors-compose) + +Multiple postprocessors are chained by wrapping them in a [`ReactiveMP.CompositeStreamPostprocessor`](@ref): + +```julia +postprocessor = CompositeStreamPostprocessor(( + ScheduleOnStreamPostprocessor(PendingScheduler()), + MyCustomStreamPostprocessor(), +)) +``` + +The output of stage `i` is fed as the input of stage `i + 1`, independently for each of the three stream kinds. + +## [Attaching a stream postprocessor](@id lib-stream-postprocessors-attach) + +Stream postprocessors are provided when activating a factor node via [`ReactiveMP.FactorNodeActivationOptions`](@ref) and a random variable via [`ReactiveMP.RandomVariableActivationOptions`](@ref). In practice this is done through the model specification layer (e.g. [RxInfer.jl](https://github.com/ReactiveBayes/RxInfer.jl)'s `@model` macro), but at the low level it looks like: + +```julia +postprocessor = ScheduleOnStreamPostprocessor(PendingScheduler()) + +# For a factor node +options = ReactiveMP.FactorNodeActivationOptions( + metadata, + dependencies, + postprocessor, # <-- attached to all streams produced for this node + annotations, + rulefallback, + callbacks, +) +ReactiveMP.activate!(node, options) + +# For a random variable +ReactiveMP.activate!( + var, + ReactiveMP.RandomVariableActivationOptions( + postprocessor, + ReactiveMP.MessageProductContext(), + ReactiveMP.MessageProductContext(), + ), +) +``` + +The same postprocessor instance is applied to every outbound message stream, every marginal stream, and every score stream produced by these activations. A subtype of [`ReactiveMP.AbstractStreamPostprocessor`](@ref) must therefore implement every `postprocess_stream_of_*` method that the kinds of streams it is attached to will go through; to opt out for a particular kind of stream, just forward the stream unchanged. + +## [Custom stream postprocessors](@id lib-stream-postprocessors-custom) + +Custom postprocessors are created by subtyping [`ReactiveMP.AbstractStreamPostprocessor`](@ref) and implementing one or more of [`ReactiveMP.postprocess_stream_of_outbound_messages`](@ref), [`ReactiveMP.postprocess_stream_of_marginals`](@ref), and [`ReactiveMP.postprocess_stream_of_scores`](@ref): + +```julia +using Rocket + +struct MyStreamPostprocessor <: ReactiveMP.AbstractStreamPostprocessor end + +# Postprocess outbound messages — `tap` performs a side effect and forwards +# the value unchanged. +function ReactiveMP.postprocess_stream_of_outbound_messages(::MyStreamPostprocessor, stream) + return stream |> tap(msg -> println("Intercepted: ", msg)) +end + +# Pass marginals and scores through unchanged. +ReactiveMP.postprocess_stream_of_marginals(::MyStreamPostprocessor, stream) = stream +ReactiveMP.postprocess_stream_of_scores(::MyStreamPostprocessor, stream) = stream +``` + +If a postprocessor is attached to a stream whose corresponding `postprocess_stream_of_*` method is not implemented for it, a `MethodError` is raised at activation time. To pass a kind of stream through unchanged, simply return the input stream as shown above. + +## API reference + +```@docs +ReactiveMP.AbstractStreamPostprocessor +ReactiveMP.postprocess_stream_of_outbound_messages +ReactiveMP.postprocess_stream_of_marginals +ReactiveMP.postprocess_stream_of_scores +ReactiveMP.CompositeStreamPostprocessor +ReactiveMP.ScheduleOnStreamPostprocessor +``` diff --git a/docs/src/lib/variables.md b/docs/src/lib/variables.md new file mode 100644 index 000000000..e2c96edc2 --- /dev/null +++ b/docs/src/lib/variables.md @@ -0,0 +1,140 @@ + +# [Variables](@id lib-variables) + +Variables are fundamental building blocks of a [factor graph](@ref concepts-factor-graphs). Each variable represents either a latent quantity to be inferred, an observed data point, or a fixed constant. All variable types are subtypes of [`ReactiveMP.AbstractVariable`](@ref). + +## [Choosing the right variable type](@id lib-variables-choosing) + +There are three kinds of variables, each with a distinct role: + +| Type | Constructor | Role | Can be updated? | +|------|-------------|------|----------------| +| [`ReactiveMP.RandomVariable`](@ref) | [`ReactiveMP.randomvar`](@ref) | Latent quantity to be inferred | No — inference updates its marginal | +| [`ReactiveMP.DataVariable`](@ref) | [`ReactiveMP.datavar`](@ref) | Observed quantity that receives data | Yes — via [`new_observation!`](@ref) | +| [`ReactiveMP.ConstVariable`](@ref) | [`ReactiveMP.constvar`](@ref) | Fixed constant, never changes | No — wired at construction time | + +The choice of variable type affects how the engine allocates streams and handles messages: + +- Use `randomvar` for any quantity you want to infer a posterior over. +- Use `datavar` for observations that may change between inference calls (e.g., in online or streaming settings). +- Use `constvar` for fixed hyperparameters, known constants, or any value that will never change. + +## [Variables as reactive streams](@id lib-variables-streams) + +In ReactiveMP.jl, a variable is not a single value — it is a source of reactive *streams*. Each variable holds: + +- A **marginal stream** ([`ReactiveMP.MarginalObservable`](@ref)) that emits updated [`Marginal`](@ref) beliefs as inference progresses. +- One **message stream** ([`ReactiveMP.MessageObservable`](@ref)) per connected factor node, carrying messages flowing between the variable and that node. + +These streams are *lazy*: they are allocated during [construction](@ref concepts-inference-lifecycle-construction) but carry no values until the graph is [activated](@ref concepts-inference-lifecycle-activation). After activation, feeding new data into a `datavar` triggers automatic propagation through the network, updating marginals reactively. + +See [Inference lifecycle](@ref concepts-inference-lifecycle) for an overview of the construction → activation → observation flow. + +```@docs +ReactiveMP.AbstractVariable +``` + +## [Common variable API](@id lib-variables-common) + +All variable types share a common interface for querying their kind, degree, and reactive streams. + +### Type predicates + +```@docs +ReactiveMP.israndom +ReactiveMP.isdata +ReactiveMP.isconst +ReactiveMP.degree +``` + +### Message stream allocation + +```@docs +ReactiveMP.create_new_stream_of_inbound_messages! +``` + +### Marginal and message streams + +Every variable exposes a *marginal stream* — a reactive observable that emits updated `Marginal` values as inference progresses. The stream is accessed via [`ReactiveMP.get_stream_of_marginals`](@ref) and wired up during graph activation via [`ReactiveMP.set_stream_of_marginals!`](@ref). Initial beliefs can be seeded before inference starts with [`ReactiveMP.set_initial_marginal!`](@ref), and initial messages with [`ReactiveMP.set_initial_message!`](@ref). + +```@docs +ReactiveMP.get_stream_of_marginals +ReactiveMP.set_stream_of_marginals! +ReactiveMP.set_initial_marginal! +ReactiveMP.set_initial_message! +``` + +### Prediction streams + +A *prediction stream* gives an estimate of what the variable's value would look like from the model's perspective — without conditioning on observed data for that variable. It is accessed via [`ReactiveMP.get_stream_of_predictions`](@ref) and connected during graph activation. + +```@docs +ReactiveMP.get_stream_of_predictions +ReactiveMP.set_stream_of_predictions! +``` + +## [Random variables](@id lib-variables-random) + +Random variables represent latent (unobserved) quantities in the model. During inference, messages flow through them to update the marginal belief. + +```@docs +ReactiveMP.RandomVariable +ReactiveMP.randomvar +``` + +### Stream creation + +A `RandomVariable` starts empty: its `input_messages` and `output_messages` collections are empty vectors. Each time a factor node connects to the variable, `ReactiveMP.create_new_stream_of_inbound_messages!` is called, which allocates a new `MessageObservable{AbstractMessage}`, appends it to `input_messages`, and returns it together with its index. The returned stream becomes the *outbound* message stream from the factor node's perspective (the message the node will send toward the variable). At this point, the degree equals the number of connected nodes. All streams are unconnected (lazy) until activation. + +### Activation + +```@docs +ReactiveMP.RandomVariableActivationOptions +ReactiveMP.activate!(::RandomVariable, ::RandomVariableActivationOptions) +``` + +The prediction stream for a `RandomVariable` is identical to its marginal stream, since there is no dedicated prediction channel for latent variables. + +## [Data variables](@id lib-variables-data) + +Data variables represent observed quantities. Their value is not fixed at creation time and can be updated later via [`new_observation!`](@ref). + +```@docs +ReactiveMP.DataVariable +ReactiveMP.datavar +ReactiveMP.new_observation! +``` + +### Stream creation + +A `DataVariable` has two distinct directions of information flow: + +- **Outbound (observation) stream** — a `RecentSubject{Message}` stored in `messageout`. Calling [`new_observation!`](@ref) pushes a new `Message(PointMass(value), false, false)` into this subject. Every factor node connected to the variable receives the same shared `messageout` stream as its inbound message source; `ReactiveMP.get_stream_of_outbound_messages` always returns `messageout` regardless of the connection index. +- **Inbound (backward) messages** — each connecting factor node gets its own `MessageObservable{AbstractMessage}` allocated in `input_messages` via `ReactiveMP.create_new_stream_of_inbound_messages!`, the same way as for `RandomVariable`. These carry messages flowing *back* from the graph toward the data edge. + +All streams are unconnected (lazy) until activation. + +### Activation + +```@docs +ReactiveMP.DataVariableActivationOptions +ReactiveMP.activate!(::DataVariable, ::DataVariableActivationOptions) +``` + +## [Constant variables](@id lib-variables-constant) + +Constant variables hold a fixed value, wrapped in a `PointMass` distribution. Messages from constant variables are always marked as clamped. + +```@docs +ReactiveMP.ConstVariable +ReactiveMP.constvar +``` + +### Stream creation + +Unlike `RandomVariable` and `DataVariable`, a `ConstVariable` wires up its streams at *construction* time, not during graph activation. The constructor immediately connects: + +- `messageout` to `of(Message(PointMass(constant), true, false))` — a single-element observable that emits one clamped message and completes. +- `marginal` to `of(Marginal(PointMass(constant), true, false))` — similarly fixed and clamped. + +When a factor node connects to a `ConstVariable`, `ReactiveMP.create_new_stream_of_inbound_messages!` increments the `nconnected` counter (which defines [`ReactiveMP.degree`](@ref)) and returns the *same shared* `messageout` stream for every connection. There are no per-connection inbound streams: `ReactiveMP.get_stream_of_inbound_messages` raises an error because a `ConstVariable` never receives messages from nodes. Calling [`ReactiveMP.set_stream_of_marginals!`](@ref) or [`ReactiveMP.set_stream_of_predictions!`](@ref) also raises an error, since the streams are fixed and cannot be rewired. Constant variables require no activation step. diff --git a/docs/src/migration-guides/v5-to-v6.md b/docs/src/migration-guides/v5-to-v6.md new file mode 100644 index 000000000..3856b0196 --- /dev/null +++ b/docs/src/migration-guides/v5-to-v6.md @@ -0,0 +1,280 @@ +# Migration guide: v5 to v6 + +This guide covers the breaking changes introduced in ReactiveMP.jl v6 and how to update your code. + +## Overview + +v6 introduces three major changes: + +1. **Annotations system** — the addon system is replaced by a new annotations system. Messages and marginals now carry an [`ReactiveMP.AnnotationDict`](@ref) instead of a typed tuple of addons. Annotation processors ([`ReactiveMP.AbstractAnnotations`](@ref) subtypes) handle post-processing externally. + +2. **Stream postprocessors** — the `AbstractPipelineStage` API and the per-node `scheduler` argument have been unified into a single [`ReactiveMP.AbstractStreamPostprocessor`](@ref) abstraction that postprocesses outbound message streams, marginal streams, and score streams uniformly. + +3. **Renamed API** — many internal and public functions have been renamed to be more descriptive and consistent. The old names are removed; see the tables below for the mapping. + +## Type parameter changes + +`Message` and `Marginal` each lost one type parameter: + +```julia +# v5 +Message{D, A} +Marginal{D, A} + +# v6 +Message{D} +Marginal{D} +``` + +Code that dispatches on the second type parameter (e.g. `::Message{D, Nothing}`) must be updated to use only one parameter. + +## Constructor changes + +The fourth positional argument (`addons`) is replaced by an optional [`ReactiveMP.AnnotationDict`](@ref). In most cases you can simply drop the fourth argument: + +```julia +# v5 +Message(dist, false, false, nothing) +Marginal(dist, false, false, nothing) + +# v6 +Message(dist, false, false) +Marginal(dist, false, false) +``` + +## Renamed functions + +### Variable API + +| v5 | v6 | +|---|---| +| `update!(datavar, value)` | `new_observation!(datavar, value)` | +| `getmarginal(variable)` | `ReactiveMP.get_stream_of_marginals(variable)` | +| `getmarginals(variables)` | `map(ReactiveMP.get_stream_of_marginals, variables)` | +| `getprediction(variable)` | `ReactiveMP.get_stream_of_predictions(variable)` | +| `getpredictions(variables)` | `map(ReactiveMP.get_stream_of_predictions, variables)` | +| `setmarginal!(variable, value)` | `ReactiveMP.set_initial_marginal!(variable, value)` | +| `setmarginals!(variables, values)` | `ReactiveMP.set_initial_marginal!.(variables, values)` | +| `setmessage!(variable, value)` | `ReactiveMP.set_initial_message!(variable, value)` | +| `setmessages!(variables, values)` | `ReactiveMP.set_initial_message!.(variables, values)` | + +### Node interface API + +| v5 | v6 | +|---|---| +| `messagein(interface)` | `ReactiveMP.get_stream_of_inbound_messages(interface)` | +| `messageout(interface)` | `ReactiveMP.get_stream_of_outbound_messages(interface)` | +| `create_messagein!(variable)` | `ReactiveMP.create_new_stream_of_inbound_messages!(variable)` | + +### Skip strategy API + +| v5 | v6 | +|---|---| +| `SkipInitial()` as strategy argument | `skip_initial()` as a pipe operator | +| `SkipClamped()` as strategy argument | `skip_clamped()` as a pipe operator | +| `SkipClampedAndInitial()` as strategy argument | `skip_clamped_and_initial()` as a pipe operator | +| `IncludeAll()` | *(no filter needed)* | + +The old strategies were passed to `getmarginal` as a second argument. In v6 the observable returned by `ReactiveMP.get_stream_of_marginals` is filtered directly with a pipe: + +```julia +# v5 +obs = getmarginal(variable, SkipInitial()) + +# v6 +obs = ReactiveMP.get_stream_of_marginals(variable) |> skip_initial() +``` + +### Annotation / addon functions + +| v5 | v6 | Notes | +|---|---|---| +| `getaddons(msg)` | `getannotations(msg)` | Works on both `Message` and `Marginal` | +| `getlogscale(msg)` | `getlogscale(getannotations(msg))` | No longer a direct method on messages/marginals | +| `getmemory(msg)` | `get_rule_input_arguments(getannotations(msg))` | Renamed concept: "memory" is now "input arguments" | +| `getmemoryaddon(msg)` | *removed* | Use `get_rule_input_arguments(getannotations(msg))` | + +## [From `AbstractPipelineStage` + scheduler to `AbstractStreamPostprocessor`](@id v5-to-v6-stream-postprocessors) + +The `AbstractPipelineStage` hierarchy and the separate node-level `scheduler` argument have been replaced by a single [`ReactiveMP.AbstractStreamPostprocessor`](@ref) abstraction. The new API is described in detail on the [Stream postprocessors](@ref lib-stream-postprocessors) page; this section summarises the mechanical migration. + +### Activation options + +`FactorNodeActivationOptions` lost both its `pipeline` and `scheduler` positional fields and gained a single `postprocessor` field: + +```julia +# v5 / early v6 +FactorNodeActivationOptions(metadata, dependencies, pipeline, annotations, scheduler, rulefallback, callbacks) + +# v6 +FactorNodeActivationOptions(metadata, dependencies, postprocessor, annotations, rulefallback, callbacks) +``` + +`RandomVariableActivationOptions` had its `scheduler` field renamed to `stream_postprocessor`: + +```julia +# v5 / early v6 +RandomVariableActivationOptions(AsapScheduler(), prod_context_msg, prod_context_marginal) + +# v6 +RandomVariableActivationOptions(nothing, prod_context_msg, prod_context_marginal) +# or with an explicit postprocessor: +RandomVariableActivationOptions(ScheduleOnStreamPostprocessor(PendingScheduler()), prod_context_msg, prod_context_marginal) +``` + +`nothing` is the no-op postprocessor: each `postprocess_stream_of_*` method has a `::Nothing` pass-through fallback. + +### Type and helper renaming + +| v5 | v6 | +|---|---| +| `AbstractPipelineStage` | [`ReactiveMP.AbstractStreamPostprocessor`](@ref) | +| `EmptyPipelineStage()` / `collect_pipeline(_, nothing)` | `nothing` (uses the `::Nothing` pass-through fallback) | +| `CompositePipelineStage(stages)` | [`ReactiveMP.CompositeStreamPostprocessor`](@ref)`(stages)` | +| `ScheduleOnPipelineStage(scheduler)` | [`ReactiveMP.ScheduleOnStreamPostprocessor`](@ref)`(scheduler)` | +| `apply_pipeline_stage(stage, factornode, tag, stream)` | [`ReactiveMP.postprocess_stream_of_outbound_messages`](@ref)`(postprocessor, stream)` | +| `getscheduler(options)` | `getpostprocessor(options)` | +| `getpipeline(options)` | `getpostprocessor(options)` | +| `collect_pipeline(_, ...)` | *removed* — postprocessors are passed through unchanged | +| `+` composition of stages | wrap in `CompositeStreamPostprocessor((left, right))` | + +### Removed pipeline stages + +The following pipeline stages are gone with no direct replacement: + +| Removed | Replacement | +|---|---| +| `LoggerPipelineStage` | Use [callbacks](@ref lib-callbacks) (e.g. message-product / post-rule callbacks) instead — they observe the same events without subscribing to the streams. | +| `AsyncPipelineStage` | Wrap a Rocket.jl `AsyncScheduler` in a [`ReactiveMP.ScheduleOnStreamPostprocessor`](@ref). | +| `DiscontinuePipelineStage` | Removed; was unused. Implement a custom `AbstractStreamPostprocessor` if needed. | +| `schedule_updates(vars; pipeline_stage = ...)` | Construct a [`ReactiveMP.ScheduleOnStreamPostprocessor`](@ref) and pass it via [`ReactiveMP.RandomVariableActivationOptions`](@ref). | + +### Custom pipeline stages + +If you implemented a custom `AbstractPipelineStage`, port it to `AbstractStreamPostprocessor`. The stage signature loses the `factornode` and `tag` arguments — postprocessors operate on streams uniformly and have no node context: + +```julia +# v5 / early v6 +struct MyStage <: ReactiveMP.AbstractPipelineStage end +function ReactiveMP.apply_pipeline_stage(::MyStage, factornode, tag, stream) + return stream |> ... +end + +# v6 +struct MyStreamPostprocessor <: ReactiveMP.AbstractStreamPostprocessor end +ReactiveMP.postprocess_stream_of_outbound_messages(::MyStreamPostprocessor, stream) = stream |> ... +ReactiveMP.postprocess_stream_of_marginals(::MyStreamPostprocessor, stream) = stream +ReactiveMP.postprocess_stream_of_scores(::MyStreamPostprocessor, stream) = stream +``` + +Implement only the stream kinds you actually attach the postprocessor to; if a stream of an unsupported kind reaches it during activation, a `MethodError` is raised. To pass a kind through unchanged, return the stream as-is. + +## Removed types and functions + +The following exports no longer exist in v6: + +| Removed | Replacement | +|---|---| +| `AbstractAddon` | [`ReactiveMP.AbstractAnnotations`](@ref) | +| `AddonLogScale` | [`ReactiveMP.LogScaleAnnotations`](@ref) | +| `AddonMemory` | [`ReactiveMP.InputArgumentsAnnotations`](@ref) | +| `AddonDebug` | *removed* (use [callbacks](@ref lib-callbacks) instead) | +| `multiply_addons` | [`ReactiveMP.post_product_annotations!`](@ref) | +| `@invokeaddon` | *removed* (macros like `@logscale` call `annotate!` directly) | +| `message_mapping_addons` | *removed* | +| `message_mapping_addon` | *removed* | +| `MarginalSkipStrategy` | `skip_initial()`, `skip_clamped()`, `skip_clamped_and_initial()` filter operators | +| `SkipClamped` | `skip_clamped()` | +| `SkipInitial` | `skip_initial()` | +| `SkipClampedAndInitial` | `skip_clamped_and_initial()` | +| `IncludeAll` | *(no filter needed)* | +| `apply_skip_filter` | *removed* | +| `as_marginal_observable` | *removed* | +| `AbstractVariable` | `ReactiveMP.AbstractVariable` (no longer exported) | +| `update!` | `new_observation!` | +| `getmarginal`, `getmarginals` | `ReactiveMP.get_stream_of_marginals` | +| `getprediction`, `getpredictions` | `ReactiveMP.get_stream_of_predictions` | +| `setmarginal!`, `setmarginals!` | `ReactiveMP.set_initial_marginal!` | +| `setmessage!`, `setmessages!` | `ReactiveMP.set_initial_message!` | +| `messagein` | `ReactiveMP.get_stream_of_inbound_messages` | +| `messageout` | `ReactiveMP.get_stream_of_outbound_messages` | +| `create_messagein!` | `ReactiveMP.create_new_stream_of_inbound_messages!` | +| `AbstractPipelineStage` and subtypes (`EmptyPipelineStage`, `CompositePipelineStage`, `ScheduleOnPipelineStage`, `LoggerPipelineStage`, `AsyncPipelineStage`, `DiscontinuePipelineStage`), `apply_pipeline_stage`, `collect_pipeline`, `schedule_updates`, `getpipeline`, `getscheduler` | See the [stream postprocessor migration section](@ref v5-to-v6-stream-postprocessors) above | + +## Writing custom rules + +Rules continue to return only the distribution result — no change needed for most rules. The `@logscale` macro now writes directly into the [`ReactiveMP.AnnotationDict`](@ref) instead of going through `@invokeaddon`: + +```julia +# v5 +@rule MyNode(:out, Marginalisation) (m_in::PointMass,) = begin + result = compute_something(m_in) + @logscale 0 + return result +end + +# v6 — identical, no changes needed +@rule MyNode(:out, Marginalisation) (m_in::PointMass,) = begin + result = compute_something(m_in) + @logscale 0 + return result +end +``` + +Inside a `@rule` body, `getannotations()` (previously `getaddons()`) returns the [`ReactiveMP.AnnotationDict`](@ref) for the current rule execution. The `@logscale value` macro is a shorthand for `annotate!(getannotations(), :logscale, value)`. + +## Testing rules with `@call_rule` + +The `@call_rule` macro no longer supports `return_addons` or `addons` keyword arguments. Use the `annotations` keyword to pass an [`ReactiveMP.AnnotationDict`](@ref) and read it back after the call: + +```julia +# v5 +result, addons = @call_rule [ return_addons = true ] MyNode(:out, Marginalisation) ( + m_in = PointMass(1.0), addons = (AddonLogScale(),) +) +logscale = getlogscale(addons) + +# v6 +ann = AnnotationDict() +result = @call_rule MyNode(:out, Marginalisation) (m_in = PointMass(1.0), annotations = ann,) +logscale = getlogscale(ann) +``` + +## Writing custom annotation processors + +If you had a custom `AbstractAddon` subtype, migrate it to an [`ReactiveMP.AbstractAnnotations`](@ref) subtype. See the [Annotations overview](@ref lib-annotations) for a complete guide. + +```julia +# v5 +struct MyAddon <: AbstractAddon end + +# Had to implement multiply_addons and handle tuple-based dispatch + +# v6 +struct MyAnnotations <: AbstractAnnotations end + +# Implement these two methods: +function ReactiveMP.post_rule_annotations!(::MyAnnotations, ann::AnnotationDict, mapping, messages, marginals, result) + annotate!(ann, :my_key, compute_something(result)) +end + +function ReactiveMP.post_product_annotations!(::MyAnnotations, merged::AnnotationDict, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, left_dist, right_dist) + # Merge annotations from left and right into merged +end +``` + +See [`ReactiveMP.post_rule_annotations!`](@ref) and [`ReactiveMP.post_product_annotations!`](@ref). + +## Configuring annotations in RxInfer + +When setting up inference with RxInfer, replace addon configuration with the equivalent annotation processors: + +```julia +# v5 +addons = (AddonLogScale(), AddonMemory()) + +# v6 +annotations = (LogScaleAnnotations(), InputArgumentsAnnotations()) +``` + +Refer to the RxInfer.jl documentation for the updated inference configuration API. diff --git a/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl b/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl index d5e3d9c07..ce7c76132 100644 --- a/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl +++ b/ext/ReactiveMPProjectionExt/layout/cvi_projection.jl @@ -6,16 +6,16 @@ import ReactiveMP: AbstractDeltaNodeDependenciesLayout, DeltaFnDefaultRuleLayout, DeltaFnNode, - getmarginal, + get_stream_of_marginals, functionalform, tag, Marginalisation, MessageMapping, DeferredMessage, with_statics, - apply_pipeline_stage, - messageout, - messagein, + postprocess_stream_of_outbound_messages, + set_stream_of_outbound_messages!, + get_stream_of_inbound_messages, connect! """ @@ -48,20 +48,20 @@ function deltafn_apply_layout( ::Val{:q_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_out), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -71,20 +71,20 @@ function deltafn_apply_layout( ::Val{:q_ins}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_ins), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -94,23 +94,23 @@ function deltafn_apply_layout( ::Val{:m_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) let interface = factornode.out msgs_names = Val{(:out,)}() - msgs_observable = combineLatestUpdates((messagein(factornode.out),), PushNew()) + msgs_observable = combineLatestUpdates((get_stream_of_inbound_messages(factornode.out),), PushNew()) marginal_names = Val{(:out, :ins)}() - marginals_observable = combineLatestUpdates((getmarginal(factornode.localmarginals.marginals[1]), getmarginal(factornode.localmarginals.marginals[2])), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(factornode.localmarginals.marginals[1]), get_stream_of_marginals(factornode.localmarginals.marginals[2])), PushNew()) fform = functionalform(factornode) vtag = tag(interface) vconstraint = Marginalisation() - vmessageout = combineLatest( + stream_of_outbound_messages = combineLatest( (msgs_observable, marginals_observable), PushNew() ) @@ -122,23 +122,25 @@ function deltafn_apply_layout( msgs_names, marginal_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = with_statics(factornode, vmessageout) - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline_stages, factornode, vtag, vmessageout + stream_of_outbound_messages = with_statics( + factornode, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(interface), vmessageout) + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessors, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!(interface, stream_of_outbound_messages) end end @@ -148,19 +150,19 @@ function deltafn_apply_layout( ::Val{:m_in}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:m_in), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index 3dc79f1ca..9298f975d 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -4,6 +4,7 @@ module ReactiveMP # List global dependencies here using TinyHugeNumbers, MatrixCorrectionTools, FastCholesky, LinearAlgebra using BayesBase, ExponentialFamily +using UUIDs import MatrixCorrectionTools: AbstractCorrectionStrategy, correction! @@ -21,13 +22,14 @@ include("helpers/algebra/standard_basis_vector.jl") include("constraints/form.jl") +include("callbacks.jl") +include("postprocessors.jl") +include("variable.jl") +include("annotations.jl") +include("annotations/logscale.jl") +include("annotations/input_arguments.jl") include("message.jl") include("marginal.jl") -include("addons.jl") - -include("addons/debug.jl") -include("addons/logscale.jl") -include("addons/memory.jl") """ to_marginal(any) @@ -41,20 +43,13 @@ Note: This function is a part of the private API and is not intended to be used """ to_marginal(any) = any -as_marginal(message::Message) = Marginal(to_marginal(getdata(message)), is_clamped(message), is_initial(message), getaddons(message)) -as_message(marginal::Marginal) = Message(getdata(marginal), is_clamped(marginal), is_initial(marginal), getaddons(marginal)) +as_marginal(message::Message) = Marginal(to_marginal(getdata(message)), is_clamped(message), is_initial(message), getannotations(message)) +as_message(marginal::Marginal) = Message(getdata(marginal), is_clamped(marginal), is_initial(marginal), getannotations(marginal)) getdata(::Nothing) = nothing getdata(collection::Tuple) = map(getdata, collection) getdata(collection::AbstractArray) = map(getdata, collection) -getlogscale(message::Message) = getlogscale(getaddons(message)) -getlogscale(marginal::Marginal) = getlogscale(getaddons(marginal)) -getmemoryaddon(message::Message) = getmemoryaddon(getaddons(message)) -getmemoryaddon(marginal::Marginal) = getmemoryaddon(getaddons(marginal)) -getmemory(message::Message) = getmemory(getaddons(message)) -getmemory(marginal::Marginal) = getmemory(getaddons(marginal)) - # TupleTools.prod is a more efficient version of Base.all for Tuple here is_clamped(tuple::Tuple) = TupleTools.prod(map(is_clamped, tuple)) is_initial(tuple::Tuple) = TupleTools.prod(map(is_initial, tuple)) @@ -73,22 +68,17 @@ include("approximations/unscented.jl") include("approximations/cvi.jl") include("approximations/cvi_projection.jl") +# Predefined postprocessors +include("postprocessors/scheduled.jl") + # Equality node is a special case and needs to be included before random variable implementation include("nodes/equality.jl") -include("variables/variable.jl") include("variables/random.jl") include("variables/constant.jl") include("variables/data.jl") -include("pipeline/pipeline.jl") -include("pipeline/async.jl") -include("pipeline/discontinue.jl") -include("pipeline/logger.jl") -include("pipeline/scheduled.jl") - include("nodes/nodes.jl") - include("rule.jl") include("score/score.jl") @@ -113,6 +103,27 @@ function __init__() """ println(io, errmsg) end + if exc.f === ReactiveMP.handle_event && length(argtypes) >= 2 + event_type = argtypes[2] + event_hint = if event_type <: ReactiveMP.Event + "Event{$(repr(ReactiveMP.event_name(event_type)))}" + else + string(event_type) + end + errmsg = """ + + `ReactiveMP.handle_event` was called with a callback handler of type `$(argtypes[1])` for event `$(event_type)`, but no matching method was found. This can happen if: + + 1. You implemented a custom callback handler but forgot to define `handle_event` for this specific event type. + Make sure your handler has a method like: + ReactiveMP.handle_event(::$(argtypes[1]), event::$(event_hint)) = ... + + 2. You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma. + In Julia, `(key = value)` is parsed as a plain assignment, not a NamedTuple. + Use `(key = value,)` (with a trailing comma) instead. + """ + println(io, errmsg) + end end end diff --git a/src/addons.jl b/src/addons.jl deleted file mode 100644 index 87f33d7dc..000000000 --- a/src/addons.jl +++ /dev/null @@ -1,63 +0,0 @@ -import MacroTools: @capture -import Base: string, show, + -import TupleTools - -abstract type AbstractAddon end - -multiply_addons(::Nothing, ::Nothing, ::Any, ::Any, ::Any) = nothing -multiply_addons(::Nothing, addon::Any, ::Any, ::Missing, ::Any) = addon -multiply_addons(addon::Any, ::Nothing, ::Any, ::Any, ::Missing) = addon -multiply_addons(::Nothing, ::Nothing, ::Any, ::Missing, ::Any) = nothing -multiply_addons(::Nothing, ::Nothing, ::Any, ::Any, ::Missing) = nothing -multiply_addons(::Nothing, ::Nothing, ::Any, ::Missing, ::Missing) = nothing - -function multiply_addons( - left_addons::Tuple, right_addons::Tuple, new_dist, left_dist, right_dist -) - - # perform sanity check on the length of the addons - @assert length(left_addons) == length(right_addons) "Trying to perform computations with different lengths of addons." - - # compute addon product elementwise - return map(left_addons, right_addons) do left_addon, right_addon - multiply_addons( - left_addon, right_addon, new_dist, left_dist, right_dist - ) - end -end - -# Nice functionality, allows to write `addons = Addon1() + Addon2() + ...` -+(left::AbstractAddon, right::AbstractAddon) = (left, right) -+(left::NTuple{N, AbstractAddon}, right::AbstractAddon) where {N} = ( - left..., right -) -+(left::AbstractAddon, right::NTuple{N, AbstractAddon}) where {N} = ( - left, right... -) - -macro invokeaddon(Type, callback) - # invoke addon macro can be executed only inside the @rule macro - index = gensym(:index) - addon = gensym(:addon) - value = gensym(:value) - body = quote - # First we check is the `_addons` field of the `@rule` macro has the associated addon enabled - # To do that we check if the type of the addon is present in the `_addons` tuple - _addons = if !isnothing(_addons) - # If the specified addons is enabled we find its index and replace the value with the specified body - local $index = findnext((addon) -> addon isa $(Type), _addons, 1) - if !isnothing($index) - local $addon = () -> $(Type)($(callback)) - local $value = $(addon)() - # Here we replace the previous value of the addon at the specified index - ReactiveMP.TupleTools.insertat(_addons, $index, ($value,)) - else - _addons - end - else - # If the addon is not present we simply return the result - _addons - end - end - return esc(body) -end diff --git a/src/addons/debug.jl b/src/addons/debug.jl deleted file mode 100644 index 4e252b95d..000000000 --- a/src/addons/debug.jl +++ /dev/null @@ -1,88 +0,0 @@ -export AddonDebug - -""" - AddonDebug(f :: Function) - -This addon calls the function `f` over the output of the message mapping and products. The result is expected to be boolean and when returning true, it will throw an error with the debug information. Common applications of this addon are to check for NaNs and Infs in the messages and marginals. - -## Example -```julia -checkfornans(x) = isnan(x) -checkfornans(x::AbstractArray) = any(checkfornans.(x)) -checkfornans(x::Tuple) = any(checkfornans.(x)) - -addons = (AddonDebug(dist -> checkfornans(params(dist))),) -``` -""" -struct AddonDebug <: AbstractAddon - f::Function -end - -AddonDebug() = AddonDebug(nothing) - -getdebugaddon(addons::NTuple{N, AbstractAddon}) where {N} = first( - filter(x -> typeof(x) <: AddonDebug, addons) -) - -(addon::AddonDebug)(x) = addon.f(x) - -function message_mapping_addon( - addon::AddonDebug, mapping, messages, marginals, result -) - if addon(result) - - # create error message - msg = "Debug addon triggered:\n" - msg *= "Mapping:\n" - msg *= "At the node: " * string(message_mapping_fform(mapping)) * "\n" - msg *= "Towards interface: " * string(mapping.vtag) * "\n" - msg *= "With local constraint: " * string(mapping.vconstraint) * "\n" - if !isnothing(mapping.meta) - msg *= "With meta: " * string(mapping.meta) * "\n" - end - if !isnothing(mapping.addons) - msg *= "With addons: " * string(mapping.addons) * "\n" - end - if !isnothing(messages) - msg *= "Incoming messages:\n" - for message in messages - msg *= string(message) * "\n" - end - end - if !isnothing(marginals) - msg *= "Incoming marginals:\n" - for marginal in marginals - msg *= string(marginal) * "\n" - end - end - msg *= "Result:\n" - msg *= string(result) - - # throw error - return error(msg) - end - return addon -end - -function multiply_addons( - left_addon::AddonDebug, - right_addon::AddonDebug, - new_dist, - left_dist, - right_dist, -) - if left_addon(new_dist) - - # create error message - msg = "Debug addon triggered:\n" - msg *= "Incoming distributions: \n" - msg *= string(left_dist) * "\n" - msg *= string(right_dist) * "\n" - msg *= "Resulting distribution: \n" - msg *= string(new_dist) - - # throw error - return error(msg) - end - return left_addon -end diff --git a/src/addons/logscale.jl b/src/addons/logscale.jl deleted file mode 100644 index 16a75ecb9..000000000 --- a/src/addons/logscale.jl +++ /dev/null @@ -1,98 +0,0 @@ -export AddonLogScale, getlogscale - -using Distributions - -import Base: prod, string - -struct AddonLogScale{T} <: AbstractAddon - logscale::T -end - -AddonLogScale() = AddonLogScale(nothing) - -getlogscale(addon::AddonLogScale) = addon.logscale -getlogscale(::Nothing) = error( - "Log-scale addon is not available. Make sure to include AddonLogScale in the addons. Currently, log scale factors are only supported for very specific nodes and messages in sum-product updates. Extensions to variational message passing are not yet supported.", -) - -function getlogscale(addons::NTuple{N, AbstractAddon}) where {N} - logscales = filter(addon -> addon isa AddonLogScale, addons) - if length(logscales) === 0 - error( - "Log-scale addon is not available. Make sure to include AddonLogScale in the addons.", - ) - end - return mapreduce(getlogscale, +, logscales) -end - -function message_mapping_addon( - ::AddonLogScale{Nothing}, mapping, messages, marginals, result::Distribution -) - # Here we assume - # 1. If log-scale value has not been computed during the message update rule - # 2. Either all messages or marginals are of type PointMass - # 3. The result of the message update rule is a proper distribution - # THEN: logscale is equal to zero - # OTHERWISE: show an error - # This logic probably can be improved, e.g. if some tracks conjugacy between the node and messages - if isnothing(marginals) && all(data -> data isa PointMass, messages) - return AddonLogScale(0) - elseif isnothing(messages) && all(data -> data isa PointMass, marginals) - return AddonLogScale(0) - else - error( - "Log-scale value has not been computed for the message update rule = $(mapping)", - ) - end -end - -# Log scale macro for the message update rules -macro logscale(lambda) - @capture(lambda, (body_)) || - error("Error in macro. Lambda body specification is incorrect") - # return expression for @logscale - return esc(:(ReactiveMP.@invokeaddon AddonLogScale $body)) -end - -function multiply_addons( - left_addon::AddonLogScale{Missing}, - right_addon::AddonLogScale, - new_dist, - left_dist::Missing, - right_dist, -) - return right_addon -end - -function multiply_addons( - left_addon::AddonLogScale, - right_addon::AddonLogScale{Missing}, - new_dist, - left_dist, - right_dist::Missing, -) - return left_addon -end - -function multiply_addons( - left_addon::AddonLogScale, - right_addon::AddonLogScale, - new_dist, - left_dist, - right_dist, -) - - # fetch log scales from addons - left_logscale = getlogscale(left_addon) - right_logscale = getlogscale(right_addon) - - # compute new logscale - new_logscale = compute_logscale(new_dist, left_dist, right_dist) - - # return updated logscale addon - return AddonLogScale(left_logscale + right_logscale + new_logscale) -end - -function string(addon::AddonLogScale) - return string("log-scale = ", getlogscale(addon), "; ") -end diff --git a/src/addons/memory.jl b/src/addons/memory.jl deleted file mode 100644 index dd86a98be..000000000 --- a/src/addons/memory.jl +++ /dev/null @@ -1,127 +0,0 @@ -export AddonMemory, getmemory - -import Base: prod, string, show - -struct AddonMemory{T} <: AbstractAddon - memory::T -end - -AddonMemory() = AddonMemory(nothing) - -getmemoryaddon(addons::NTuple{N, AbstractAddon}) where {N} = first( - filter(x -> typeof(x) <: AddonMemory, addons) -) -getmemory(addon::AddonMemory) = addon.memory -getmemory(addons::NTuple{N, AbstractAddon}) where {N} = - getmemoryaddon(addons).memory - -struct AddonMemoryMessageMapping{M <: MessageMapping, S, L, R} - mapping :: M - messages :: S - marginals :: L - result :: R -end - -struct AddonMemoryProd{T} - mappings::Vector{T} -end - -function message_mapping_addon( - ::AddonMemory{Nothing}, mapping, messages, marginals, result -) - return AddonMemory( - AddonMemoryMessageMapping(mapping, messages, marginals, result) - ) -end - -function multiply_addons( - left_addon::AddonMemory, - right_addon::AddonMemory, - new_dist, - left_dist, - right_dist, -) - return AddonMemory( - construct_memory(getmemory(left_addon), getmemory(right_addon)) - ) -end - -function construct_memory( - left::AddonMemoryMessageMapping, right::AddonMemoryMessageMapping -) - return AddonMemoryProd(Any[left, right]) -end - -function construct_memory( - left::AddonMemoryMessageMapping, right::AddonMemoryProd -) - pushfirst!(right.mappings, left) - return right -end - -function construct_memory( - left::AddonMemoryProd, right::AddonMemoryMessageMapping -) - push!(left.mappings, right) - return left -end - -function construct_memory(left::AddonMemoryProd, right::AddonMemoryProd) - append!(left.mappings, right.mappings) - return left -end - -function string(::AddonMemory) - return string("memory present; ") -end - -show(io::IO, addon::AddonMemory) = print( - io, string("AddonMemory(", addon.memory, ")") -) - -function show(io::IO, addon::AddonMemoryMessageMapping) - indent = get(io, :indent, 0) - println(io, ' ', "Message mapping memory:") - println( - io, ' '^indent, "At the node: ", message_mapping_fform(addon.mapping) - ) - println(io, ' '^indent, "Towards interface: ", addon.mapping.vtag) - println( - io, ' '^indent, "With local constraint: ", addon.mapping.vconstraint - ) - if !isnothing(addon.mapping.meta) - println(io, ' '^indent, "With meta: ", addon.mapping.meta) - end - if !isnothing(addon.mapping.addons) - println(io, ' '^indent, "With addons: ", addon.mapping.addons) - end - if !isnothing(addon.messages) - println( - io, - ' '^indent, - "With input messages on ", - addon.mapping.msgs_names, - " edges: ", - addon.messages, - ) - end - if !isnothing(addon.marginals) - println( - io, - ' '^indent, - "With input marginals on ", - addon.mapping.marginals_names, - " edges: ", - addon.marginals, - ) - end - println(io, ' '^indent, "With the result: ", addon.result) -end - -function show(io::IO, addon::AddonMemoryProd) - indent = get(io, :indent, 0) - println(io, ' '^indent, "Product memory:") - for message in addon.mappings - show(IOContext(io, :indent => indent + 4), message) - end -end diff --git a/src/annotations.jl b/src/annotations.jl new file mode 100644 index 000000000..8f65af455 --- /dev/null +++ b/src/annotations.jl @@ -0,0 +1,170 @@ +export getannotations + +""" + AnnotationDict() + AnnotationDict(other::AnnotationDict) + +A mutable dictionary that associates `Symbol` keys with arbitrary annotation values. +Supports lazy initialization — no memory is allocated until the first write. + +The copy constructor creates an independent shallow copy of `other`. +""" +mutable struct AnnotationDict + data::Union{Nothing, Dict{Symbol, Any}} + + function AnnotationDict() + return new(nothing) + end + + function AnnotationDict(other::AnnotationDict) + return new( + if isnothing(other.data) + nothing + else + copy(other.data::Dict{Symbol, Any}) + end, + ) + end +end + +# Overloaded later for `::Message` and `::Marginal` in their respective files +function getannotations end + +Base.isempty(ann::AnnotationDict) = + isnothing(ann.data) || isempty(ann.data::Dict{Symbol, Any}) + +function Base.show(io::IO, ann::AnnotationDict) + if isempty(ann) + print(io, "AnnotationDict()") + else + print(io, "AnnotationDict(") + join(io, ("$k => $v" for (k, v) in ann.data::Dict{Symbol, Any}), ", ") + print(io, ")") + end +end + +function Base.:(==)(left::AnnotationDict, right::AnnotationDict) + return left.data == right.data +end + +""" + has_annotation(ann::AnnotationDict, key::Symbol) -> Bool + +Return `true` if `ann` contains an entry for `key`, `false` otherwise. +""" +function has_annotation(ann::AnnotationDict, key::Symbol) + return !isnothing(ann.data) && haskey(ann.data::Dict{Symbol, Any}, key) +end + +""" + annotate!(ann::AnnotationDict, key::Symbol, value) + +Store `value` under `key` in `ann`. Always returns `nothing`. +""" +function annotate!(ann::AnnotationDict, key::Symbol, value) + if isnothing(ann.data) + data = Dict{Symbol, Any}(key => value) + ann.data = data + else + (ann.data::Dict{Symbol, Any})[key] = value + end + return nothing +end + +""" + get_annotation(ann::AnnotationDict, key::Symbol) + +Return the value stored under `key`. Throws `KeyError` if `key` is absent. +""" +function get_annotation(ann::AnnotationDict, key::Symbol) + if isnothing(ann.data) + throw(KeyError(key)) + end + return (ann.data::Dict{Symbol, Any})[key] +end + +""" + get_annotation(ann::AnnotationDict, ::Type{T}, key::Symbol) where {T} + +Return the value stored under `key`, converted to type `T`. Throws `KeyError` if +`key` is absent. +""" +function get_annotation(ann::AnnotationDict, ::Type{T}, key::Symbol) where {T} + return convert(T, get_annotation(ann, key))::T +end + +""" + AbstractAnnotations + +Abstract base type for annotation processors. Subtypes define how annotations +are written into messages after rule execution and merged during message products. + +See also: [`post_product_annotations!`](@ref), [`post_rule_annotations!`](@ref) +""" +abstract type AbstractAnnotations end + +""" + post_product_annotations!(processor::AbstractAnnotations, merged::AnnotationDict, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, left_dist, right_dist) + +Write annotations into `merged` based on `left_ann`, `right_ann`, and the distributions +involved in the message product. Called once per processor inside +[`compute_product_of_two_messages`](@ref). +""" +function post_product_annotations! end + +""" + pre_rule_annotations!(processor::AbstractAnnotations, ann::AnnotationDict, mapping, messages, marginals) + +Write annotations into `ann` before a rule has executed. Called once per processor +inside the `MessageMapping` callable, before the rule returns its result distribution. +""" +function pre_rule_annotations! end + +""" + post_rule_annotations!(processor::AbstractAnnotations, ann::AnnotationDict, mapping, messages, marginals, result) + +Write annotations into `ann` after a rule has executed. Called once per processor +inside the `MessageMapping` callable, after the rule returns its result distribution. +""" +function post_rule_annotations! end + +""" + post_product_annotations!(processors, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, left_dist, right_dist) -> AnnotationDict + +Produce a merged `AnnotationDict` from the annotations of two messages being multiplied. +Called inside [`compute_product_of_two_messages`](@ref). + +If `left_dist` is `missing` the right annotations are copied through unchanged, and vice versa. +If both are `missing`, or if `processors` is `nothing`, an empty `AnnotationDict` is returned. +Otherwise each processor in `processors` is called via the per-processor `post_product_annotations!` +to populate the result. +""" +function post_product_annotations!( + processors, + left_ann::AnnotationDict, + right_ann::AnnotationDict, + new_dist, + left_dist, + right_dist, +) + merged = AnnotationDict() + if isnothing(processors) + return merged + end + for p in processors + post_product_annotations!( + p, merged, left_ann, right_ann, new_dist, left_dist, right_dist + ) + end + return merged +end + +post_product_annotations!(processors, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, ::Missing, ::Missing) = AnnotationDict() + +post_product_annotations!(processors, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, ::Missing, right_dist) = AnnotationDict( + right_ann +) + +post_product_annotations!(processors, left_ann::AnnotationDict, right_ann::AnnotationDict, new_dist, left_dist, ::Missing) = AnnotationDict( + left_ann +) diff --git a/src/annotations/input_arguments.jl b/src/annotations/input_arguments.jl new file mode 100644 index 000000000..624ed6294 --- /dev/null +++ b/src/annotations/input_arguments.jl @@ -0,0 +1,184 @@ +export InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + get_rule_input_arguments, + AddonMemory + +""" + RuleInputArgumentsRecord + +Stores the inputs and result of a single message update rule execution: the +`MessageMapping`, the incoming messages tuple, the incoming marginals tuple, and +the computed result distribution. +""" +struct RuleInputArgumentsRecord + mapping + messages + marginals + result +end + +""" + ProductInputArgumentsRecord + +Stores the collection of [`RuleInputArgumentsRecord`](@ref) objects that were +combined during one or more message products. Each element corresponds to one +rule execution that contributed to the product. +""" +struct ProductInputArgumentsRecord + mappings::Vector{RuleInputArgumentsRecord} +end + +""" + InputArgumentsAnnotations <: AbstractAnnotations + +Annotation processor that records the input arguments and result of each +message update rule execution and propagates them through message products. + +After a rule executes, stores a [`RuleInputArgumentsRecord`](@ref) under the +`:rule_input_arguments` key of the annotation dict. During message products, +merges the records from the two sides into a [`ProductInputArgumentsRecord`](@ref). +""" +struct InputArgumentsAnnotations <: AbstractAnnotations end + +""" + get_rule_input_arguments(ann::AnnotationDict) + +Return the rule input arguments stored in `ann`. The value is a +[`RuleInputArgumentsRecord`](@ref) when the message came directly from a single +rule execution, or a [`ProductInputArgumentsRecord`](@ref) when it is the result +of one or more message products. Throws `KeyError` if the annotation has not been +set. +""" +get_rule_input_arguments(ann::AnnotationDict) = get_annotation( + ann, :rule_input_arguments +) + +function pre_rule_annotations!( + ::InputArgumentsAnnotations, + ann::AnnotationDict, + mapping, + messages, + marginals, +) + return nothing +end + +function post_rule_annotations!( + ::InputArgumentsAnnotations, + ann::AnnotationDict, + mapping, + messages, + marginals, + result, +) + annotate!( + ann, + :rule_input_arguments, + RuleInputArgumentsRecord(mapping, messages, marginals, result), + ) + return nothing +end + +function _merge_input_arguments( + left::RuleInputArgumentsRecord, right::RuleInputArgumentsRecord +) + return ProductInputArgumentsRecord(RuleInputArgumentsRecord[left, right]) +end + +function _merge_input_arguments( + left::RuleInputArgumentsRecord, right::ProductInputArgumentsRecord +) + pushfirst!(right.mappings, left) + return right +end + +function _merge_input_arguments( + left::ProductInputArgumentsRecord, right::RuleInputArgumentsRecord +) + push!(left.mappings, right) + return left +end + +function _merge_input_arguments( + left::ProductInputArgumentsRecord, right::ProductInputArgumentsRecord +) + append!(left.mappings, right.mappings) + return left +end + +function post_product_annotations!( + ::InputArgumentsAnnotations, + merged::AnnotationDict, + left_ann::AnnotationDict, + right_ann::AnnotationDict, + new_dist, + left_dist, + right_dist, +) + left_record = get_rule_input_arguments(left_ann) + right_record = get_rule_input_arguments(right_ann) + annotate!( + merged, + :rule_input_arguments, + _merge_input_arguments(left_record, right_record), + ) + return nothing +end + +function Base.show(io::IO, record::RuleInputArgumentsRecord) + indent = get(io, :indent, 0) + pad = ' '^indent + mapping = record.mapping + println(io, pad, "Rule input arguments:") + println(io, pad, " node: ", message_mapping_fform(mapping)) + println(io, pad, " interface: ", mapping.vtag) + println(io, pad, " constraint: ", mapping.vconstraint) + if !isnothing(mapping.meta) + println(io, pad, " meta: ", mapping.meta) + end + if !isnothing(record.messages) + names = unval(mapping.msgs_names) + for (name, msg) in zip(names, record.messages) + println(io, pad, " msg(", name, ") = ", msg) + end + end + if !isnothing(record.marginals) + names = unval(mapping.marginals_names) + for (name, mar) in zip(names, record.marginals) + println(io, pad, " q(", name, ") = ", mar) + end + end + print(io, pad, " result: ", record.result) +end + +function Base.show(io::IO, record::ProductInputArgumentsRecord) + indent = get(io, :indent, 0) + pad = ' '^indent + println( + io, + pad, + "Product of ", + length(record.mappings), + " rule input arguments:", + ) + inner = IOContext(io, :indent => indent + 4) + for (i, r) in enumerate(record.mappings) + println(inner, pad, " [", i, "]") + show(inner, r) + i < length(record.mappings) && println(io) + end +end + +""" + AddonMemory(args...; kwargs...) + +Deprecated: `AddonMemory` has been removed in ReactiveMP v6. +Use [`InputArgumentsAnnotations`](@ref) instead. See the migration guide in the documentation for details. +""" +function AddonMemory(args...; kwargs...) + error( + """`AddonMemory` has been removed in ReactiveMP v6 and replaced by `InputArgumentsAnnotations`. """ * + """See the migration guide in the documentation for details.""", + ) +end diff --git a/src/annotations/logscale.jl b/src/annotations/logscale.jl new file mode 100644 index 000000000..6d780b4d0 --- /dev/null +++ b/src/annotations/logscale.jl @@ -0,0 +1,86 @@ +export LogScaleAnnotations, getlogscale, AddonLogScale + +""" + LogScaleAnnotations <: AbstractAnnotations + +Annotation processor that tracks the log-scale factor of a message. +Writes the `:logscale` annotation during rule execution (via `@logscale`) and +merges it across message products by summing the left and right log-scales and +adding the normalisation correction from `compute_logscale`. +""" +struct LogScaleAnnotations <: AbstractAnnotations end + +""" + getlogscale(ann::AnnotationDict) + +Return the log-scale value stored in `ann`. Throws `KeyError` if the logscale +annotation has not been set. +""" +getlogscale(ann::AnnotationDict) = get_annotation(ann, :logscale) + +""" + @logscale value + +Set the log-scale annotation on the current rule's annotation dict. +Intended to be called inside a `@rule` body. Expands to +`annotate!(getannotations(), :logscale, value)`. +""" +macro logscale(value) + return esc(:(ReactiveMP.annotate!(getannotations(), :logscale, $(value)))) +end + +function pre_rule_annotations!( + ::LogScaleAnnotations, ann::AnnotationDict, mapping, messages, marginals +) + return nothing +end + +function post_rule_annotations!( + ::LogScaleAnnotations, + ann::AnnotationDict, + mapping, + messages, + marginals, + result, +) + has_annotation(ann, :logscale) && return nothing + if isnothing(marginals) && all(m -> getdata(m) isa PointMass, messages) + annotate!(ann, :logscale, 0) + elseif isnothing(messages) && all(m -> getdata(m) isa PointMass, marginals) + annotate!(ann, :logscale, 0) + else + error( + "Log-scale annotation has not been set for the message update rule = $(mapping)", + ) + end + return nothing +end + +function post_product_annotations!( + ::LogScaleAnnotations, + merged::AnnotationDict, + left_ann::AnnotationDict, + right_ann::AnnotationDict, + new_dist, + left_dist, + right_dist, +) + left_logscale = getlogscale(left_ann) + right_logscale = getlogscale(right_ann) + new_logscale = compute_logscale(new_dist, left_dist, right_dist) + annotate!(merged, :logscale, left_logscale + right_logscale + new_logscale) + return nothing +end + +""" + AddonLogScale(args...; kwargs...) + +Deprecated: `AddonLogScale` has been removed in ReactiveMP v6. +Use [`LogScaleAnnotations`](@ref) instead. See the migration guide in the documentation for details. +""" +function AddonLogScale(args...; kwargs...) + error( + """`AddonLogScale` has been removed in ReactiveMP v6 and replaced by `LogScaleAnnotations`. """ * + """See the migration guide in the documentation for details.""", + ) +end diff --git a/src/callbacks.jl b/src/callbacks.jl new file mode 100644 index 000000000..5ee60ff65 --- /dev/null +++ b/src/callbacks.jl @@ -0,0 +1,420 @@ +using UUIDs + +""" + Event{E} + +Abstract supertype for all callback events in the reactive message passing procedure. +`E` is a `Symbol` that identifies the event, e.g. `Event{:before_message_rule_call}`. + +Concrete event types should subtype `Event{:event_name}` and carry the relevant data as fields. +The naming convention is that for an event `:event_name`, the corresponding struct is called `EventNameEvent`. + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.handle_event`](@ref) +""" +abstract type Event{E} end + +""" + event_name(::Type{<:Event{E}}) where {E} + event_name(event::Event) + +Returns the event name symbol `E` from an `Event{E}` type, subtype, or instance. +""" +event_name(::Type{<:Event{E}}) where {E} = E +event_name(event::Event) = event_name(typeof(event)) + +""" + handle_event(handler, event::Event) + +Custom callback handlers should implement `handle_event` to listen to events +during the reactive message passing procedure. +Each event is a subtype of [`ReactiveMP.Event{E}`](@ref) that carries the relevant data as fields. +The return value of `handle_event` is ignored. To communicate state changes, use mutable event fields. + +```jldoctest +julia> struct MyEvent <: ReactiveMP.Event{:my_event} + value::Int + end; + +julia> struct MyCustomCallbackHandler end; + +julia> ReactiveMP.handle_event(::MyCustomCallbackHandler, event::MyEvent) = print("Event value: \$(event.value)"); +``` + +See also: [`ReactiveMP.Event`](@ref), [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.merge_callbacks`](@ref) +""" +function handle_event end + +""" + invoke_callback(callbacks, event::Event) + +Invokes the callback handler(s) for the given event and returns the `event` itself. +Internally dispatches to [`ReactiveMP.handle_event`](@ref) for each handler. +Does nothing and returns the event if `callbacks` is `nothing`. + +See also: [`ReactiveMP.handle_event`](@ref), [`ReactiveMP.Event`](@ref), [`ReactiveMP.merge_callbacks`](@ref) +""" +function invoke_callback(callbacks::Nothing, event::Event) + return event +end + +""" + invoke_callback(callbacks::NamedTuple, event::Event{E}) + +The `callbacks` can also be a `NamedTuple` with fields corresponding to event names. +Each callback function receives the event object itself. The return value of the callback is ignored. + +```jldoctest +julia> mutable struct CountEvent <: ReactiveMP.Event{:count_event} + count::Int + end; + +julia> callbacks = (count_event = (event) -> event.count += 1,); + +julia> event = CountEvent(0); + +julia> ReactiveMP.invoke_callback(callbacks, event); + +julia> event.count +1 +``` + +If the `NamedTuple` does not have a field corresponding to the event name, the event will be ignored. +""" +function invoke_callback(callbacks::NamedTuple{K}, event::Event{E}) where {K, E} + if E in K + callbacks[E](event) + end + return event +end + +""" + invoke_callback(callbacks::Dict{Symbol}, event::Event{E}) + +The `callbacks` can also be a `Dict{Symbol, Any}` with keys corresponding to event names. +Works the same as the `NamedTuple` variant, but allows dynamic construction of callback handlers at runtime. +Each callback function receives the event object itself. The return value of the callback is ignored. + +If the `Dict` does not have a key corresponding to the event name, the event will be ignored. +""" +function invoke_callback(callbacks::Dict{Symbol}, event::Event{E}) where {E} + if haskey(callbacks, E) + callbacks[E](event) + end + return event +end + +""" + invoke_callback(handler, event::Event) + +Fallback for custom callback handlers. Delegates to [`ReactiveMP.handle_event`](@ref) and returns the `event`. +Custom handlers should implement `handle_event(handler, event)` rather than `invoke_callback`. +""" +function invoke_callback(handler, event::Event) + handle_event(handler, event) + return event +end + +""" + MergedCallbacks{C}(callbacks) + +The result of the [`ReactiveMP.merge_callbacks`](@ref) procedure. +""" +struct MergedCallbacks{C} + callbacks::C +end + +""" + merge_callbacks(callbacks_handlers...) + +This function accepts an arbitrary amount of callback handlers and merges them together. +Some callback handlers may or may not react on certain types of events. + +```jldoctest +julia> struct PrintEvent <: ReactiveMP.Event{:print_event} + label::String + end; + +julia> handler1 = (print_event = (event) -> println("Handler 1: ", event.label),); + +julia> handler2 = (print_event = (event) -> println("Handler 2: ", event.label),); + +julia> merged_handler = ReactiveMP.merge_callbacks(handler1, handler2); + +julia> ReactiveMP.invoke_callback(merged_handler, PrintEvent("hello")); +Handler 1: hello +Handler 2: hello +``` + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.handle_event`](@ref) +""" +function merge_callbacks(callback_handlers...) + return MergedCallbacks(callback_handlers) +end + +""" + invoke_callback(merged::MergedCallbacks, event::Event) + +A specialized version of [`ReactiveMP.invoke_callback`](@ref) for [`ReactiveMP.MergedCallbacks`](@ref). +Calls the provided callbacks in order. Returns the event after all handlers have been invoked. +""" +function invoke_callback(merged::MergedCallbacks, event::Event) + for callback in merged.callbacks + invoke_callback(callback, event) + end + return event +end + +""" + generate_span_id(callbacks) + +Generates a unique identifier used for "before" and "after" events (see for example [`BeforeMessageRuleCallEvent`](@ref) and [`AfterMessageRuleCallEvent]`](@ref)). If callbacks are not set (e.g. `callbacks` is `nothing`), returns `nothing`. + +The current implementation uses `UUIDs.uuid4` to generate span IDs, but that may change in the future. +""" +function generate_span_id end + +function generate_span_id(::Nothing) + return nothing +end + +function generate_span_id(callbacks) + return uuid4() +end + +# All defined events go here, so its easier to document them all in one place + +""" + BeforeMessageRuleCallEvent{M, Ms, Mr} <: Event{:before_message_rule_call} + +This event fires right before computing the message and calling the corresponding rule. + +# Fields +- `mapping`: of type [`ReactiveMP.MessageMapping`](@ref), contains information about the node type, etc +- `messages`: typically of type `Tuple` if present, `nothing` otherwise +- `marginals`: typically of type `Tuple` if present, `nothing` otherwise +- `span_id`: an id shared with the corresponding [`ReactiveMP.AfterMessageRuleCallEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.AfterMessageRuleCallEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct BeforeMessageRuleCallEvent{M, Ms, Mr, S} <: + Event{:before_message_rule_call} + mapping::M + messages::Ms + marginals::Mr + span_id::S +end + +""" + AfterMessageRuleCallEvent{M, Ms, Mr, R, A} <: Event{:after_message_rule_call} + +This event fires right after computing the message and calling the corresponding rule. + +# Fields +- `mapping`: of type [`ReactiveMP.MessageMapping`](@ref), contains information about the node type, etc +- `messages`: typically of type `Tuple` if present, `nothing` otherwise +- `marginals`: typically of type `Tuple` if present, `nothing` otherwise +- `result`: the result of the rule invocation (or `rulefallback`), can be any type +- `annotations`: the annotations attached to the result, of type [`ReactiveMP.AnnotationDict`](@ref) +- `span_id`: an id shared with the corresponding [`ReactiveMP.BeforeMessageRuleCallEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.BeforeMessageRuleCallEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct AfterMessageRuleCallEvent{M, Ms, Mr, R, A, S} <: + Event{:after_message_rule_call} + mapping::M + messages::Ms + marginals::Mr + result::R + annotations::A + span_id::S +end + +""" + BeforeProductOfTwoMessagesEvent{V, C, L, R} <: Event{:before_product_of_two_messages} + +This event fires right before computing the product of two messages. + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `left`: of type [`ReactiveMP.Message`](@ref), the left-hand side message in the product +- `right`: of type [`ReactiveMP.Message`](@ref), the right-hand side message in the product +- `span_id`: an id shared with the corresponding [`ReactiveMP.AfterProductOfTwoMessagesEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.AfterProductOfTwoMessagesEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct BeforeProductOfTwoMessagesEvent{V, C, L, R, S} <: + Event{:before_product_of_two_messages} + variable::V + context::C + left::L + right::R + span_id::S +end + +""" + AfterProductOfTwoMessagesEvent{V, C, L, R, Rs, A} <: Event{:after_product_of_two_messages} + +This event fires right after computing the product of two messages. + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `left`: of type [`ReactiveMP.Message`](@ref), the left-hand side message in the product +- `right`: of type [`ReactiveMP.Message`](@ref), the right-hand side message in the product +- `result`: of type [`ReactiveMP.Message`](@ref), the resulting message from the product +- `annotations`: the annotations attached to the result, of type [`ReactiveMP.AnnotationDict`](@ref) +- `span_id`: an id shared with the corresponding [`ReactiveMP.BeforeProductOfTwoMessagesEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.BeforeProductOfTwoMessagesEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct AfterProductOfTwoMessagesEvent{V, C, L, R, Rs, A, S} <: + Event{:after_product_of_two_messages} + variable::V + context::C + left::L + right::R + result::Rs + annotations::A + span_id::S +end + +""" + BeforeProductOfMessagesEvent{V, C, Ms} <: Event{:before_product_of_messages} + +This event fires right before computing the product of a collection of messages +(i.e. at the beginning of [`ReactiveMP.compute_product_of_messages`](@ref)). + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`: the collection of messages to be multiplied +- `span_id`: an id shared with the corresponding [`ReactiveMP.AfterProductOfMessagesEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.AfterProductOfMessagesEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct BeforeProductOfMessagesEvent{V, C, Ms, S} <: + Event{:before_product_of_messages} + variable::V + context::C + messages::Ms + span_id::S +end + +""" + AfterProductOfMessagesEvent{V, C, Ms, R} <: Event{:after_product_of_messages} + +This event fires right after computing the product of a collection of messages +(i.e. at the end of [`ReactiveMP.compute_product_of_messages`](@ref)). + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`: the original collection of messages that were multiplied +- `result`: of type [`ReactiveMP.Message`](@ref), the final result after folding and form constraint application +- `span_id`: an id shared with the corresponding [`ReactiveMP.BeforeProductOfMessagesEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.BeforeProductOfMessagesEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct AfterProductOfMessagesEvent{V, C, Ms, R, S} <: + Event{:after_product_of_messages} + variable::V + context::C + messages::Ms + result::R + span_id::S +end + +""" + BeforeFormConstraintAppliedEvent{V, C, S, D} <: Event{:before_form_constraint_applied} + +This event fires right before applying the form constraint via [`ReactiveMP.constrain_form`](@ref). +Fires in both [`ReactiveMP.FormConstraintCheckEach`](@ref) and [`ReactiveMP.FormConstraintCheckLast`](@ref) strategies. + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `strategy`: the form constraint check strategy being used (e.g. [`ReactiveMP.FormConstraintCheckEach`](@ref) or [`ReactiveMP.FormConstraintCheckLast`](@ref)) +- `distribution`: the distribution about to be constrained +- `span_id`: an id shared with the corresponding [`ReactiveMP.AfterFormConstraintAppliedEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.AfterFormConstraintAppliedEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct BeforeFormConstraintAppliedEvent{V, C, S, D, I} <: + Event{:before_form_constraint_applied} + variable::V + context::C + strategy::S + distribution::D + span_id::I +end + +""" + AfterFormConstraintAppliedEvent{V, C, S, D, R} <: Event{:after_form_constraint_applied} + +This event fires right after applying the form constraint via [`ReactiveMP.constrain_form`](@ref). +Fires in both [`ReactiveMP.FormConstraintCheckEach`](@ref) and [`ReactiveMP.FormConstraintCheckLast`](@ref) strategies. + +# Fields +- `variable`: of type [`ReactiveMP.AbstractVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `strategy`: the form constraint check strategy being used (e.g. [`ReactiveMP.FormConstraintCheckEach`](@ref) or [`ReactiveMP.FormConstraintCheckLast`](@ref)) +- `distribution`: the distribution before the constraint was applied +- `result`: the distribution after the constraint was applied +- `span_id`: an id shared with the corresponding [`ReactiveMP.BeforeFormConstraintAppliedEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.BeforeFormConstraintAppliedEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct AfterFormConstraintAppliedEvent{V, C, S, D, R, I} <: + Event{:after_form_constraint_applied} + variable::V + context::C + strategy::S + distribution::D + result::R + span_id::I +end + +""" + BeforeMarginalComputationEvent{V, C, Ms} <: Event{:before_marginal_computation} + +This event fires right before computing the marginal for a [`ReactiveMP.RandomVariable`](@ref) from its incoming messages. + +# Fields +- `variable`: of type [`ReactiveMP.RandomVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`: the collection of incoming messages used to compute the marginal +- `span_id`: an id shared with the corresponding [`ReactiveMP.AfterMarginalComputationEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.AfterMarginalComputationEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct BeforeMarginalComputationEvent{V, C, Ms, S} <: + Event{:before_marginal_computation} + variable::V + context::C + messages::Ms + span_id::S +end + +""" + AfterMarginalComputationEvent{V, C, Ms, R} <: Event{:after_marginal_computation} + +This event fires right after computing the marginal for a [`ReactiveMP.RandomVariable`](@ref) from its incoming messages. + +# Fields +- `variable`: of type [`ReactiveMP.RandomVariable`](@ref) +- `context`: of type [`ReactiveMP.MessageProductContext`](@ref) +- `messages`: the collection of incoming messages used to compute the marginal +- `result`: the computed marginal +- `span_id`: an id shared with the corresponding [`ReactiveMP.BeforeMarginalComputationEvent`](@ref) + +See also: [`ReactiveMP.invoke_callback`](@ref), [`ReactiveMP.BeforeMarginalComputationEvent`](@ref), [`ReactiveMP.generate_span_id`](@ref) +""" +struct AfterMarginalComputationEvent{V, C, Ms, R, S} <: + Event{:after_marginal_computation} + variable::V + context::C + messages::Ms + result::R + span_id::S +end diff --git a/src/constraints/form.jl b/src/constraints/form.jl index df8996442..968d81bb1 100644 --- a/src/constraints/form.jl +++ b/src/constraints/form.jl @@ -9,72 +9,84 @@ using TupleTools import BayesBase: resolve_prod_strategy import Base: + -# Form constraints are preserved during execution of the `prod` function -# There are two major strategies to check current functional form -# We may check and preserve functional form of the result of the `prod` function -# after each subsequent `prod` -# or we may want to wait after all `prod` functions in the equality chain have been executed +# Form constraints control the functional form of messages during the product computation. +# There are two strategies for when to apply the constraint: +# - `FormConstraintCheckEach`: apply after each pairwise `prod` in `compute_product_of_two_messages` +# - `FormConstraintCheckLast`: apply once at the end in `compute_product_of_messages` """ AbstractFormConstraint -Every functional form constraint is a subtype of `AbstractFormConstraint` abstract type. +Abstract supertype for all form constraints. Subtype this to create custom form constraints +that can be used with [`constrain_form`](@ref) and [`ReactiveMP.MessageProductContext`](@ref). -Note: this is not strictly necessary, but it makes automatic dispatch easier and compatible with the `CompositeFormConstraint`. +Not strictly required (any object works via [`ReactiveMP.WrappedFormConstraint`](@ref)), +but makes dispatch easier and is needed for [`ReactiveMP.CompositeFormConstraint`](@ref) composition via `+`. """ abstract type AbstractFormConstraint end """ FormConstraintCheckEach -This form constraint check strategy checks functional form of the messages product after each product in an equality chain. -Usually if a variable has been connected to multiple nodes we want to perform multiple `prod` to obtain a posterior marginal. -With this form check strategy `constrain_form` function will be executed after each subsequent `prod` function. +Form constraint check strategy that applies [`constrain_form`](@ref) after **each** pairwise product +inside [`ReactiveMP.compute_product_of_two_messages`](@ref). Use this when intermediate results +need to stay in a specific functional form (e.g. to prevent numerical issues during long product chains). + +See also: [`FormConstraintCheckLast`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ struct FormConstraintCheckEach end """ - FormConstraintCheckEach + FormConstraintCheckLast -This form constraint check strategy checks functional form of the last messages product in the equality chain. -Usually if a variable has been connected to multiple nodes we want to perform multiple `prod` to obtain a posterior marginal. -With this form check strategy `constrain_form` function will be executed only once after all subsequenct `prod` functions have been executed. +Form constraint check strategy that applies [`constrain_form`](@ref) **once** at the very end +of [`ReactiveMP.compute_product_of_messages`](@ref), after all pairwise products have been folded. +This is the default strategy and is more efficient when intermediate form doesn't matter. + +See also: [`FormConstraintCheckEach`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ struct FormConstraintCheckLast end """ FormConstraintCheckPickDefault -This form constraint check strategy simply fallbacks to a default check strategy for a given form constraint. +A meta-strategy that defers to the default check strategy of the given form constraint, +as defined by [`default_form_check_strategy`](@ref). """ struct FormConstraintCheckPickDefault end """ default_form_check_strategy(form_constraint) -Returns a default check strategy (e.g. `FormConstraintCheckEach` or `FormConstraintCheckEach`) for a given form constraint object. +Returns the default check strategy (either [`FormConstraintCheckEach`](@ref) or [`FormConstraintCheckLast`](@ref)) +for a given form constraint. Override this for custom constraints to control when they are applied. """ function default_form_check_strategy end """ default_prod_constraint(form_constraint) -Returns a default prod constraint needed to apply a given `form_constraint`. For most form constraints this function returns `ProdGeneric`. +Returns the default product strategy needed to apply a given `form_constraint`. +For most form constraints this returns `BayesBase.GenericProd()`. """ function default_prod_constraint end """ - constrain_form(constraint, something) + constrain_form(constraint, distribution) + +Applies the form `constraint` to `distribution` and returns the constrained result. +This is the main extension point for custom form constraints — implement a method of this function +for your constraint type and the distribution types you want to support. -This function applies a given form constraint to a given object. +See also: [`AbstractFormConstraint`](@ref), [`ReactiveMP.MessageProductContext`](@ref) """ function constrain_form end """ UnspecifiedFormConstraint -One of the form constraint objects. Does not imply any form constraints and simply returns the same object as receives. -However it does not allow `DistProduct` to be a valid functional form in the inference backend. +The default form constraint — does nothing and returns the distribution as-is. +Used when no form constraint has been specified in the [`ReactiveMP.MessageProductContext`](@ref). """ struct UnspecifiedFormConstraint <: AbstractFormConstraint end @@ -87,9 +99,10 @@ constrain_form(::UnspecifiedFormConstraint, something) = something """ WrappedFormConstraint(constraint, context) -This is a wrapper for a form constraint object. It allows to pass additional context to the `constrain_form` function. -By default all objects that are not sub-typed from `AbstractFormConstraint` are wrapped into this object. -Use `ReactiveMP.prepare_context` to provide an extra context for a given form constraint, that can be reused between multiple `constrain_form` calls. +A wrapper that pairs a form constraint with an optional precomputed context. +Any object that is not a subtype of [`AbstractFormConstraint`](@ref) gets automatically wrapped +into this during [`ReactiveMP.preprocess_form_constraints`](@ref). +Use [`ReactiveMP.prepare_context`](@ref) to provide extra context that can be reused across multiple [`constrain_form`](@ref) calls. """ struct WrappedFormConstraint{C, X} <: AbstractFormConstraint constraint::C @@ -101,15 +114,16 @@ struct WrappedFormConstraintNoContext end """ prepare_context(constraint) -This function prepares a context for a given form constraint. Returns `WrappedFormConstraintNoContext` if no context is needed (the default behaviour). +Prepares a reusable context for a given form constraint. Returns `WrappedFormConstraintNoContext` by default (i.e. no context needed). +Override this to precompute things that should be shared across multiple [`constrain_form`](@ref) calls. """ prepare_context(constraint) = WrappedFormConstraintNoContext() """ constrain_form(wrapped::WrappedFormConstraint, something) -This function unwraps the `wrapped` object and calls `constrain_form` function with the provided context. -If the context is not provided, simply calls `constrain_form` with the wrapped constraint. Otherwise passes the context to the `constrain_form` function as the second argument. +Unwraps the constraint and delegates to [`constrain_form`](@ref) with the inner constraint. +If a context was provided via [`ReactiveMP.prepare_context`](@ref), it is passed as the second argument. """ constrain_form(wrapped::WrappedFormConstraint, something) = constrain_form( wrapped, wrapped.context, something @@ -131,8 +145,9 @@ default_prod_constraint(wrapped::WrappedFormConstraint) = default_prod_constrain """ preprocess_form_constraints(constraints) -This function preprocesses form constraints and converts the provided objects into a form compatible with ReactiveMP inference backend (if possible). -If a tuple of constraints is passed, it creates a `CompositeFormConstraint` object. Wraps unknown form constraints into a `WrappedFormConstraint` object. +Converts form constraints into a form compatible with the ReactiveMP inference backend. +A tuple of constraints becomes a [`ReactiveMP.CompositeFormConstraint`](@ref). +Objects that are not subtypes of [`AbstractFormConstraint`](@ref) get wrapped into a [`ReactiveMP.WrappedFormConstraint`](@ref). """ function preprocess_form_constraints end @@ -147,7 +162,9 @@ preprocess_form_constraints(constraint) = WrappedFormConstraint( """ CompositeFormConstraint -Creates a composite form constraint that applies form constraints in order. The composed form constraints must be compatible and have the exact same `form_check_strategy`. +A form constraint that chains multiple constraints together, applying them in order via [`constrain_form`](@ref). +Create one by combining constraints with `+` (e.g. `constraint_a + constraint_b`). +All composed constraints must share the same [`default_form_check_strategy`](@ref). """ struct CompositeFormConstraint{C} <: AbstractFormConstraint constraints::C diff --git a/src/marginal.jl b/src/marginal.jl index 71099de4d..8e95e56b1 100644 --- a/src/marginal.jl +++ b/src/marginal.jl @@ -1,5 +1,4 @@ export Marginal, getdata, is_clamped, is_initial, as_marginal -export SkipClamped, SkipInitial, SkipClampedAndInitial, IncludeAll using Distributions using Rocket @@ -8,7 +7,7 @@ import Rocket: getrecent import Base: ==, ndims, precision, length, size, iterate """ - Marginal(data, is_clamped, is_initial, addons) + Marginal(data, is_clamped, is_initial[, annotations]) An implementation of a marginal in variational message passing framework. @@ -16,18 +15,18 @@ An implementation of a marginal in variational message passing framework. - `data::D`: marginal always holds some data object associated with it, which is usually a probability distribution - `is_clamped::Bool`, specifies if this marginal was the result of constant computations (e.g. clamped constants) - `is_initial::Bool`, specifies if this marginal was used for initialization -- `addons::A`, specifies the addons of the marginal, which may carry extra bits of information, e.g. debug information, memory, etc. +- `annotations::AnnotationDict`: optional annotation dictionary carrying extra metadata (e.g. log-scale, input arguments). Defaults to an empty `AnnotationDict()`. -# Example +# Example ```jldoctest julia> distribution = Gamma(10.0, 2.0) Distributions.Gamma{Float64}(α=10.0, θ=2.0) -julia> message = Marginal(distribution, false, true, nothing) +julia> message = Marginal(distribution, false, true) Marginal(Distributions.Gamma{Float64}(α=10.0, θ=2.0)) -julia> mean(message) +julia> mean(message) 20.0 julia> getdata(message) @@ -40,27 +39,31 @@ julia> is_initial(message) true ``` """ -mutable struct Marginal{D, A} # `mutable` structure here appears to be more performance - const data :: D # in `RxInfer` benchmarks - const is_clamped :: Bool # could be revised at some point though - const is_initial :: Bool - const addons :: A +mutable struct Marginal{D} # `mutable` structure here appears to be more performance + const data :: D # in `RxInfer` benchmarks + const is_clamped :: Bool # could be revised at some point though + const is_initial :: Bool + const annotations :: AnnotationDict end +Marginal(data, is_clamped::Bool, is_initial::Bool) = Marginal( + data, is_clamped, is_initial, AnnotationDict() +) + function Base.show(io::IO, marginal::Marginal) - print(io, string("Marginal(", getdata(marginal), ")")) - if !isnothing(getaddons(marginal)) - print(io, ") with ", string(getaddons(marginal))) + print(io, "Marginal(", getdata(marginal), ")") + ann = getannotations(marginal) + if !isempty(ann) + print(io, " with ", ann) end end function Base.:(==)(left::Marginal, right::Marginal) - # We need this dummy method as Julia is not smart enough to + # We need this dummy method as Julia is not smart enough to # do that automatically if `data` is mutable return left.is_clamped == right.is_clamped && left.is_initial == right.is_initial && - left.data == right.data && - left.addons == right.addons + left.data == right.data end """ @@ -89,11 +92,11 @@ See also: [`is_clamped`](@ref) is_initial(marginal::Marginal) = marginal.is_initial """ - getaddons(marginal::Marginal) + getannotations(marginal::Marginal) -Returns `addons` associated with the `marginal`. +Returns the [`AnnotationDict`](@ref) associated with the `marginal`. """ -getaddons(marginal::Marginal) = marginal.addons +getannotations(marginal::Marginal) = marginal.annotations typeofdata(marginal::Marginal) = typeof(getdata(marginal)) @@ -160,25 +163,25 @@ as_marginal(marginal::Marginal) = marginal dropproxytype(::Type{<:Marginal{T}}) where {T} = T +skip_initial() = filter(v -> !is_initial(v)) +skip_clamped() = filter(v -> !is_clamped(v)) +skip_clamped_and_initial() = filter(v -> !is_initial(v) && !is_clamped(v)) + ## Marginal observable -abstract type MarginalSkipStrategy end +""" + ReactiveMP.MarginalObservable -struct SkipClamped <: MarginalSkipStrategy end -struct SkipInitial <: MarginalSkipStrategy end -struct SkipClampedAndInitial <: MarginalSkipStrategy end -struct IncludeAll <: MarginalSkipStrategy end +A lazy, connectable reactive stream for [`Marginal`](@ref) values, used as the marginal stream of every variable in the factor graph. -Base.broadcastable(::SkipClamped) = Ref(SkipClamped()) -Base.broadcastable(::SkipInitial) = Ref(SkipInitial()) -Base.broadcastable(::SkipClampedAndInitial) = Ref(SkipClampedAndInitial()) -Base.broadcastable(::IncludeAll) = Ref(IncludeAll()) +Internally combines two Rocket.jl primitives: +- a `RecentSubject{Marginal}` that caches the most recently emitted value, so `Rocket.getrecent` always returns the latest belief and late subscribers receive it immediately +- a `LazyObservable{Marginal}` that is the actual subscription target — initially unconnected, and wired to an upstream source during graph activation via `ReactiveMP.connect!` -apply_skip_filter(observable, ::SkipClamped) = observable |> filter(v -> !is_clamped(v)) -apply_skip_filter(observable, ::SkipInitial) = observable |> filter(v -> !is_initial(v)) -apply_skip_filter(observable, ::SkipClampedAndInitial) = observable |> filter(v -> !is_initial(v) && !is_clamped(v)) -apply_skip_filter(observable, ::IncludeAll) = observable +`connect!(observable, source)` sets the lazy stream to `source |> multicast(subject) |> ref_count()`: all subscribers share one upstream subscription, and every emission is forwarded through the cached subject. Before the upstream is connected, [`ReactiveMP.set_initial_marginal!`](@ref) can push an initial belief directly into the subject to seed the graph before inference begins. +See also: [`ReactiveMP.MessageObservable`](@ref), [`ReactiveMP.get_stream_of_marginals`](@ref), [`ReactiveMP.set_initial_marginal!`](@ref) +""" struct MarginalObservable <: Subscribable{Marginal} subject :: Rocket.RecentSubjectInstance{Marginal, Subject{Marginal, AsapScheduler, AsapScheduler}} stream :: LazyObservable{Marginal} @@ -188,15 +191,6 @@ MarginalObservable() = MarginalObservable( RecentSubject(Marginal), lazy(Marginal) ) -as_marginal_observable(observable::MarginalObservable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(observable, skip_strategy) -as_marginal_observable(observable) = as_marginal_observable(observable, IncludeAll()) - -function as_marginal_observable(observable, skip_strategy::MarginalSkipStrategy) - output = MarginalObservable() - connect!(output, observable) - return as_marginal_observable(output, skip_strategy) -end - Rocket.getrecent(observable::MarginalObservable) = Rocket.getrecent( observable.subject ) @@ -221,8 +215,8 @@ function connect!(marginal::MarginalObservable, source) return nothing end -function setmarginal!(marginal::MarginalObservable, value) - next!(marginal.subject, Marginal(value, false, true, nothing)) +function set_initial_marginal!(marginal::MarginalObservable, value) + next!(marginal.subject, Marginal(value, false, true)) return nothing end @@ -299,7 +293,7 @@ function (mapping::MarginalMapping)(dependencies) ) end - return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing) + return Marginal(marginal, is_marginal_clamped, is_marginal_initial) end Base.map(::Type{T}, mapping::M) where {T, M <: MarginalMapping} = Rocket.MapOperator{ diff --git a/src/message.jl b/src/message.jl index 468815e75..989e279b3 100644 --- a/src/message.jl +++ b/src/message.jl @@ -6,6 +6,7 @@ using Rocket import Rocket: getrecent import Base: ==, *, +, ndims, precision, length, size, show +import BayesBase: prod """ An abstract supertype for all concrete message types. @@ -13,7 +14,7 @@ An abstract supertype for all concrete message types. abstract type AbstractMessage end """ - Message(data, is_clamped, is_initial, addons) + Message(data, is_clamped, is_initial[, annotations]) An implementation of a message in variational message passing framework. @@ -21,18 +22,18 @@ An implementation of a message in variational message passing framework. - `data::D`: message always holds some data object associated with it, which is usually a probability distribution, but can also be an arbitrary function - `is_clamped::Bool`, specifies if this message was the result of constant computations (e.g. clamped constants) - `is_initial::Bool`, specifies if this message was used for initialization -- `addons::A`, specifies the addons of the message, which may carry extra bits of information, e.g. debug information, memory, etc. +- `annotations::AnnotationDict`: optional annotation dictionary carrying extra metadata (e.g. log-scale, input arguments). Defaults to an empty `AnnotationDict()`. -# Example +# Example ```jldoctest julia> distribution = Gamma(10.0, 2.0) Distributions.Gamma{Float64}(α=10.0, θ=2.0) -julia> message = Message(distribution, false, true, nothing) +julia> message = Message(distribution, false, true) Message(Distributions.Gamma{Float64}(α=10.0, θ=2.0)) -julia> mean(message) +julia> mean(message) 20.0 julia> getdata(message) @@ -46,13 +47,17 @@ true ``` """ -mutable struct Message{D, A} <: AbstractMessage # `mutable` structure here appears to be more performance - const data :: D # in `RxInfer` benchmarks - const is_clamped :: Bool # could be revised at some point though - const is_initial :: Bool - const addons :: A +mutable struct Message{D} <: AbstractMessage # `mutable` structure here appears to be more performance + const data :: D # in `RxInfer` benchmarks + const is_clamped :: Bool # could be revised at some point though + const is_initial :: Bool + const annotations :: AnnotationDict end +Message(data, is_clamped::Bool, is_initial::Bool) = Message( + data, is_clamped, is_initial, AnnotationDict() +) + """ as_message(::AbstractMessage) @@ -84,22 +89,22 @@ Checks if `message` is initial or not. is_initial(message::Message) = message.is_initial """ - getaddons(message::Message) + getannotations(message::Message) -Returns `addons` associated with the `message`. +Returns the [`AnnotationDict`](@ref) associated with the `message`. """ -getaddons(message::Message) = message.addons +getannotations(message::Message) = message.annotations typeofdata(message::Message) = typeof(getdata(message)) getdata(messages::NTuple{N, <:Message}) where {N} = map(getdata, messages) getdata(messages::AbstractArray{<:Message}) = map(getdata, messages) -# Base.show(io::IO, message::Message) = print(io, string("Message(", getdata(message), ") with ", string(getaddons(message)))) function show(io::IO, message::Message) - print(io, string("Message(", getdata(message), ")")) - if !isnothing(getaddons(message)) - print(io, ") with ", string(getaddons(message))) + print(io, "Message(", getdata(message), ")") + ann = getannotations(message) + if !isempty(ann) + print(io, " with ", ann) end end @@ -108,17 +113,69 @@ end function Base.:(==)(left::Message, right::Message) return left.is_clamped == right.is_clamped && left.is_initial == right.is_initial && - left.data == right.data && - left.addons == right.addons + left.data == right.data end """ - multiply_messages(prod_strategy, left::Message, right::Message) + MessageProductContext(kwargs...) + +The structure that defines the context for the product of **two** messages within ReactiveMP. +The product is executed with the [`ReactiveMP.compute_product_of_messages`](@ref) function and +uses the `BayesBase.prod` under the hood. See BayesBase product API documentation for detailed description. + +The following `kwargs` are supported: +- `prod_constraint`: defines the first argument for the `BayesBase.prod` function (default is `BayesBase.GenericProd`) +- `form_constraint`: defines the form constraint to be applied on the result of computation, default is [`ReactiveMP.UnspecifiedFormConstraint`](@ref) +- `form_constraint_check_strategy`: defines the strategy to check the specified form constraint, either [`ReactiveMP.FormConstraintCheckLast`](@ref) or [`ReactiveMP.FormConstraintCheckEach`](@ref), default is [`ReactiveMP.FormConstraintCheckLast`](@ref) + + [`ReactiveMP.FormConstraintCheckLast`](@ref) will only call [`ReactiveMP.constrain_form`](@ref) at the end of the `[ReactiveMP.compute_product_of_messages]` + + [`ReactiveMP.FormConstraintCheckEach`](@ref) will call [`ReactiveMP.constrain_form`](@ref) at each of the [`ReactiveMP.compute_product_of_two_messages`](@ref) +- `fold_strategy`: defines the strategy (or simply speaking the direction) of the messages product for [`ReactiveMP.compute_product_of_messages`](@ref), default is [`MessagesProductFromLeftToRight`](@ref). Can be a custom function that accepts a `variable`, `context` and collection of `messages` and does arbitrary order, but still needs to call the [`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood (unless you do some experimental stuff). By the way it is called __fold__ to reflect the computer science term with "left-fold" or "right-fold" (and we use the builtin Julia `foldl` and `foldr` functions for that). +- `callbacks`: callbacks handler, see [`ReactiveMP.invoke_callback`](@ref) for more details. + +See also: [`ReactiveMP.compute_product_of_messages`](@ref), [`ReactiveMP.compute_product_of_two_messages`] +""" +Base.@kwdef struct MessageProductContext{C, F, S, L, N, A} + prod_constraint::C = BayesBase.GenericProd() + form_constraint::F = UnspecifiedFormConstraint() + form_constraint_check_strategy::S = FormConstraintCheckLast() + fold_strategy::L = MessagesProductFromLeftToRight() + annotations::N = nothing + callbacks::A = nothing +end -Multiplies two messages `left` and `right` using a given product strategy `prod_strategy`. -Returns a new message with the result of the multiplication. Note that the resulting message is not necessarily normalized. """ -function multiply_messages(prod_strategy, left::Message, right::Message) + compute_product_of_two_messages(variable::AbstractVariable, context::MessageProductContext, left::Message, right::Message) + +Computes the product of two messages `left` and `right` for a given `variable` using the provided `context`. +Returns a new message with the result of the multiplication (not necessarily normalized). +Applies `context.form_constraint` if `context.form_constraint_check_strategy` is set to [`ReactiveMP.FormConstraintCheckEach`](@ref). + +The `variable` argument identifies which variable this product is being computed for, which is useful for callbacks (see [`ReactiveMP.BeforeProductOfTwoMessagesEvent`](@ref)). + +## `is_clamped` and `is_initial` + +The [`ReactiveMP.Message`](@ref) carries the `is_clamped` and `is_initial` flags. +The rules for the product are the following: +- If both messages are clamped, the result is clamped, OR +- If both messages are either clamped or initial, the result is initial, OR +- The result is neither clamped nor initial + +See: [`ReactiveMP.MessageProductContext`](@ref), [`ReactiveMP.compute_product_of_messages`](@ref) +""" +function compute_product_of_two_messages( + variable::AbstractVariable, + context::MessageProductContext, + left::Message, + right::Message, +) + span_id = generate_span_id(context.callbacks) + invoke_callback( + context.callbacks, + BeforeProductOfTwoMessagesEvent( + variable, context, left, right, span_id + ), + ) + # We propagate clamped message, in case if both are clamped is_prod_clamped = is_clamped(left) && is_clamped(right) # We propagate initial message, in case if both are initial or left is initial and right is clameped or vice-versa @@ -130,69 +187,176 @@ function multiply_messages(prod_strategy, left::Message, right::Message) # process distributions left_dist = getdata(left) right_dist = getdata(right) - new_dist = prod(prod_strategy, left_dist, right_dist) + new_dist = prod(context.prod_constraint, left_dist, right_dist) + + if context.form_constraint_check_strategy === FormConstraintCheckEach() + form_span_id = generate_span_id(context.callbacks) + invoke_callback( + context.callbacks, + BeforeFormConstraintAppliedEvent( + variable, + context, + FormConstraintCheckEach(), + new_dist, + form_span_id, + ), + ) + unconstrained_dist = new_dist + new_dist = constrain_form(context.form_constraint, new_dist) + invoke_callback( + context.callbacks, + AfterFormConstraintAppliedEvent( + variable, + context, + FormConstraintCheckEach(), + unconstrained_dist, + new_dist, + form_span_id, + ), + ) + end - # process addons - left_addons = getaddons(left) - right_addons = getaddons(right) + # process annotations + left_ann = getannotations(left) + right_ann = getannotations(right) + new_ann = post_product_annotations!(context.annotations, left_ann, right_ann, new_dist, left_dist, right_dist) + result = Message(new_dist, is_prod_clamped, is_prod_initial, new_ann) - # process addons - new_addons = multiply_addons( - left_addons, right_addons, new_dist, left_dist, right_dist + invoke_callback( + context.callbacks, + AfterProductOfTwoMessagesEvent( + variable, context, left, right, result, new_ann, span_id + ), ) - return Message(new_dist, is_prod_clamped, is_prod_initial, new_addons) + return result end -constrain_form_as_message(message::Message, form_constraint) = Message( - constrain_form(form_constraint, getdata(message)), - is_clamped(message), - is_initial(message), - getaddons(message), +# Sometimes we call the product on the `DeferredMessage` that need to be casted to a `Message` +function compute_product_of_two_messages( + variable::AbstractVariable, context::MessageProductContext, left, right ) + return compute_product_of_two_messages( + variable, context, as_message(left), as_message(right) + ) +end -# Note: we need extra Base.Generator(as_message, messages) step here, because some of the messages might be VMP messages -# We want to cast it explicitly to a Message structure (which as_message does in case of DeferredMessage) -# We use with Base.Generator to reduce an amount of memory used by this procedure since Generator generates items lazily -prod_foldl_reduce(prod_constraint, form_constraint, ::FormConstraintCheckEach) = - (messages) -> foldl( - (left, right) -> constrain_form_as_message( - multiply_messages(prod_constraint, left, right), - form_constraint, - ), - Base.Generator(as_message, messages), +""" + compute_product_of_messages(variable::AbstractVariable, context::MessageProductContext, messages) + +Computes the product of a **collection** of messages for a given `variable` (as opposed to [`ReactiveMP.compute_product_of_two_messages`](@ref), which handles exactly **two** messages). Uses `context.fold_strategy` to determine the order in which [`ReactiveMP.compute_product_of_two_messages`](@ref) is called. By default this is [`ReactiveMP.MessagesProductFromLeftToRight`](@ref), but can be set to an arbitrary function that accepts `variable`, `context` and `messages` and which **must** call [`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood. + +See also: [`ReactiveMP.compute_product_of_two_messages`](@ref), [`ReactiveMP.MessagesProductFromLeftToRight`](@ref) +""" +function compute_product_of_messages( + variable::AbstractVariable, context::MessageProductContext, messages +) + span_id = generate_span_id(context.callbacks) + invoke_callback( + context.callbacks, + BeforeProductOfMessagesEvent(variable, context, messages, span_id), ) -prod_foldl_reduce(prod_constraint, form_constraint, ::FormConstraintCheckLast) = - (messages) -> constrain_form_as_message( - foldl( - (left, right) -> - multiply_messages(prod_constraint, left, right), - Base.Generator(as_message, messages), + result = as_message( + compute_product_of_messages( + context.fold_strategy, variable, context, messages ), - form_constraint, ) -prod_foldr_reduce(prod_constraint, form_constraint, ::FormConstraintCheckEach) = - (messages) -> foldr( - (left, right) -> constrain_form_as_message( - multiply_messages(prod_constraint, left, right), - form_constraint, + if context.form_constraint_check_strategy === FormConstraintCheckLast() + dist = getdata(result) + form_span_id = generate_span_id(context.callbacks) + invoke_callback( + context.callbacks, + BeforeFormConstraintAppliedEvent( + variable, context, FormConstraintCheckLast(), dist, form_span_id + ), + ) + constrained_dist = constrain_form(context.form_constraint, dist) + invoke_callback( + context.callbacks, + AfterFormConstraintAppliedEvent( + variable, + context, + FormConstraintCheckLast(), + dist, + constrained_dist, + form_span_id, + ), + ) + result = Message( + constrained_dist, + is_clamped(result), + is_initial(result), + getannotations(result), + ) + end + + invoke_callback( + context.callbacks, + AfterProductOfMessagesEvent( + variable, context, messages, result, span_id ), - Base.Generator(as_message, messages), ) -prod_foldr_reduce(prod_constraint, form_constraint, ::FormConstraintCheckLast) = - (messages) -> constrain_form_as_message( - foldr( - (left, right) -> - multiply_messages(prod_constraint, left, right), - Base.Generator(as_message, messages), - ), - form_constraint, + return result +end + +""" + MessagesProductFromLeftToRight() + +The default fold strategy for [`ReactiveMP.MessageProductContext`](@ref). Computes the product of messages from left to right using `foldl` within [`ReactiveMP.compute_product_of_messages`](@ref). +""" +struct MessagesProductFromLeftToRight end + +function compute_product_of_messages( + ::MessagesProductFromLeftToRight, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return foldl( + (left, right) -> + compute_product_of_two_messages(variable, context, left, right), + messages, ) +end -# Base.:*(m1::Message, m2::Message) = multiply_messages(m1, m2) +""" + MessagesProductFromRightToLeft() + +Alternative fold strategy for [`ReactiveMP.MessageProductContext`](@ref). Computes the product of messages from right to left using `foldr` within [`ReactiveMP.compute_product_of_messages`](@ref). +""" +struct MessagesProductFromRightToLeft end + +function compute_product_of_messages( + ::MessagesProductFromRightToLeft, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return foldr( + (left, right) -> + compute_product_of_two_messages(variable, context, left, right), + messages, + ) +end + +""" + compute_product_of_messages(f::Function, variable::AbstractVariable, context::MessageProductContext, messages) + +Custom fold strategy for [`ReactiveMP.compute_product_of_messages`](@ref). When `context.fold_strategy` is set to a `Function`, +it will be called with `variable`, `context` and `messages` as arguments. The function must call +[`ReactiveMP.compute_product_of_two_messages`](@ref) under the hood to compute the pairwise products. +""" +function compute_product_of_messages( + f::Function, + variable::AbstractVariable, + context::MessageProductContext, + messages, +) + return f(variable, context, messages) +end Distributions.pdf(message::Message, x) = Distributions.pdf(getdata(message), x) Distributions.logpdf(message::Message, x) = Distributions.logpdf(getdata(message), x) @@ -291,8 +455,23 @@ end dropproxytype(::Type{<:Message{T}}) where {T} = T -## Message observable +## Message observable + +""" + ReactiveMP.MessageObservable{M <: AbstractMessage} + +A lazy, connectable reactive stream for message values of type `M <: AbstractMessage`, used as the per-connection message stream of every variable in the factor graph. + +Internally combines two Rocket.jl primitives: +- a `RecentSubject{M}` that caches the most recently emitted value, so `Rocket.getrecent` always returns the latest message and late subscribers receive it immediately +- a `LazyObservable{M}` that is the actual subscription target — initially unconnected, and wired to an upstream source during graph activation via `ReactiveMP.connect!` + +`connect!(observable, source)` sets the lazy stream to `source |> multicast(subject) |> ref_count()`: all subscribers share one upstream subscription, and every emission is forwarded through the cached subject. Before the upstream is connected, [`ReactiveMP.set_initial_message!`](@ref) can push an initial message directly into the subject to seed the graph before inference begins. +Each variable-to-node connection owns one `MessageObservable`. For [`ReactiveMP.RandomVariable`](@ref) and [`ReactiveMP.DataVariable`](@ref) these are allocated on demand by `ReactiveMP.create_new_stream_of_inbound_messages!`; for [`ReactiveMP.ConstVariable`](@ref) a single shared instance is created at construction time. + +See also: [`ReactiveMP.MarginalObservable`](@ref), [`ReactiveMP.set_initial_message!`](@ref) +""" struct MessageObservable{M <: AbstractMessage} <: Subscribable{M} subject :: Rocket.RecentSubjectInstance{M, Subject{M, AsapScheduler, AsapScheduler}} stream :: LazyObservable{M} @@ -326,8 +505,8 @@ function connect!(message::MessageObservable, source) return nothing end -function setmessage!(message::MessageObservable, value) - next!(message.subject, Message(value, false, true, nothing)) +function set_initial_message!(message::MessageObservable, value) + next!(message.subject, Message(value, false, true)) return nothing end @@ -342,7 +521,7 @@ end A callable structure representing a deferred computation of a message in the variational message passing framework. It stores all contextual information necessary to compute a message later, such as variable tags, constraints, -addons, and the associated factor node. +annotations, and the associated factor node. `MessageMapping` replaces the original lambda-based implementation to improve type stability and inference. When invoked as a function, it computes an @@ -351,70 +530,21 @@ outgoing `Message` from given input messages and marginals using the appropriate See also: [`Message`](@ref), [`DeferredMessage`](@ref) """ -struct MessageMapping{F, T, C, N, M, A, X, R, K} +struct MessageMapping{F, T, C, N, M, A, X, R, K, E} vtag :: T vconstraint :: C msgs_names :: N marginals_names :: M meta :: A - addons :: X + annotations :: X factornode :: R rulefallback :: K + callbacks :: E end message_mapping_fform(::MessageMapping{F}) where {F} = F message_mapping_fform(::MessageMapping{F}) where {F <: Function} = F.instance -# Some addons add post rule execution logic -function message_mapping_addons( - mapping::MessageMapping, messages, marginals, result, addons -) - return message_mapping_addons( - mapping, mapping.addons, messages, marginals, result, addons - ) -end - -# `enabled_addons` are always type-stable, whether `addons` are not, so we check based on the `enabled_addons` and ignore the `addons` -# As a consequence if any message update rule returns non-empty `addons`, but `enabled_addons` is empty, then the resulting value -# of the `addons` will be simply ignored -message_mapping_addons( - mapping::MessageMapping, - enabled_addons::Nothing, - messages, - marginals, - result, - addons, -) = enabled_addons -message_mapping_addons( - mapping::MessageMapping, - enabled_addons::Tuple{}, - messages, - marginals, - result, - addons, -) = enabled_addons - -# The main logic here is that some addons may add extra computation AFTER the rule has been computed -# The benefit of that is that we have an access to the `MessageMapping` structure and is mostly useful for debug addons -function message_mapping_addons( - mapping::MessageMapping, - enabled_addons::Tuple, - messages, - marginals, - result, - addons, -) - return map(addons) do addon - return message_mapping_addon( - addon, mapping, messages, marginals, result - ) - end -end - -# By default `message_mapping_addon` does nothing and simply returns the addon itself -# Other addons may override this behaviour (if necessary, see e.g. AddonMemory) -message_mapping_addon(addon, mapping, messages, marginals, result) = addon - function MessageMapping( ::Type{F}, vtag::T, @@ -422,19 +552,21 @@ function MessageMapping( msgs_names::N, marginals_names::M, meta::A, - addons::X, + annotations::X, factornode::R, rulefallback::K, -) where {F, T, C, N, M, A, X, R, K} - return MessageMapping{F, T, C, N, M, A, X, R, K}( + callbacks::E, +) where {F, T, C, N, M, A, X, R, K, E} + return MessageMapping{F, T, C, N, M, A, X, R, K, E}( vtag, vconstraint, msgs_names, marginals_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) end @@ -445,19 +577,21 @@ function MessageMapping( msgs_names::N, marginals_names::M, meta::A, - addons::X, + annotations::X, factornode::R, rulefallback::K, -) where {F <: Function, T, C, N, M, A, X, R, K} - return MessageMapping{F, T, C, N, M, A, X, R, K}( + callbacks::E, +) where {F <: Function, T, C, N, M, A, X, R, K, E} + return MessageMapping{F, T, C, N, M, A, X, R, K, E}( vtag, vconstraint, msgs_names, marginals_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) end @@ -473,13 +607,28 @@ function (mapping::MessageMapping)(messages, marginals) __check_all(is_clamped_or_initial, marginals) ) - result, addons = + span_id = generate_span_id(mapping.callbacks) + invoke_callback( + mapping.callbacks, + BeforeMessageRuleCallEvent(mapping, messages, marginals, span_id), + ) + + annotations = AnnotationDict() + + # Run annotation processors before the rule has been executed + if !isnothing(mapping.annotations) + for p in mapping.annotations + pre_rule_annotations!(p, annotations, mapping, messages, marginals) + end + end + + result = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) - missing, mapping.addons + missing elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals))) - missing, mapping.addons + missing else ruleargs = ( message_mapping_fform(mapping), @@ -490,11 +639,11 @@ function (mapping::MessageMapping)(messages, marginals) mapping.marginals_names, marginals, mapping.meta, - mapping.addons, + annotations, mapping.factornode, ) ruleoutput = rule(ruleargs...) - # if `@rule` is not defined, the default behaviour is to return + # if `@rule` is not defined, the default behaviour is to return # the `RuleMethodError` object if ruleoutput isa RuleMethodError if !isnothing(mapping.rulefallback) @@ -507,10 +656,21 @@ function (mapping::MessageMapping)(messages, marginals) end end - # Inject extra addons after the rule has been executed - addons = message_mapping_addons( - mapping, getdata(messages), getdata(marginals), result, addons + # Run annotation processors after the rule has been executed + if !isnothing(mapping.annotations) + for p in mapping.annotations + post_rule_annotations!( + p, annotations, mapping, messages, marginals, result + ) + end + end + + invoke_callback( + mapping.callbacks, + AfterMessageRuleCallEvent( + mapping, messages, marginals, result, annotations, span_id + ), ) - return Message(result, is_message_clamped, is_message_initial, addons) + return Message(result, is_message_clamped, is_message_initial, annotations) end diff --git a/src/nodes/clusters.jl b/src/nodes/clusters.jl index 3acddd961..bc2b38f27 100644 --- a/src/nodes/clusters.jl +++ b/src/nodes/clusters.jl @@ -16,9 +16,22 @@ end name(localmarginal::FactorNodeLocalMarginal) = localmarginal.name tag(localmarginal::FactorNodeLocalMarginal) = Val{name(localmarginal)}() -getmarginal(localmarginal::FactorNodeLocalMarginal) = localmarginal.marginal -setmarginal!(localmarginal::FactorNodeLocalMarginal, marginal) = +get_stream_of_marginals(localmarginal::FactorNodeLocalMarginal) = + localmarginal.marginal + +function set_stream_of_marginals!( + localmarginal::FactorNodeLocalMarginal, stream::MarginalObservable +) + localmarginal.marginal = stream +end + +function set_stream_of_marginals!( + localmarginal::FactorNodeLocalMarginal, stream +) + marginal = MarginalObservable() + connect!(marginal, stream) localmarginal.marginal = marginal +end Base.show(io::IO, marginal::FactorNodeLocalMarginal) = print( io, "FactorNodeLocalMarginal(", name(marginal), ")" @@ -31,12 +44,9 @@ struct FactorNodeLocalClusters{M, F} factorization::F end -getmarginals(clusters::FactorNodeLocalClusters) = clusters.marginals -getmarginal(clusters::FactorNodeLocalClusters, index) = getindex( - getmarginals(clusters), index -) -setmarginal!(clusters::FactorNodeLocalClusters, index, marginal::MarginalObservable) = setmarginal!( - getmarginal(clusters, index), marginal +get_node_local_marginals(clusters::FactorNodeLocalClusters) = clusters.marginals +set_node_local_marginal_stream!(clusters::FactorNodeLocalClusters, index, stream) = set_stream_of_marginals!( + clusters.marginals[index], stream ) getfactorization(clusters::FactorNodeLocalClusters) = clusters.factorization @@ -86,10 +96,10 @@ function initialize_clusters!( clusters::FactorNodeLocalClusters, dependencies, factornode, options ) # We first need to initialize all the clusters, since the `activate_cluster!` function may use any of the marginals - for i in eachindex(getmarginals(clusters)) + for i in eachindex(get_node_local_marginals(clusters)) initialize_cluster!(clusters, i, dependencies, factornode, options) end - for i in eachindex(getmarginals(clusters)) + for i in eachindex(get_node_local_marginals(clusters)) activate_cluster!(clusters, i, dependencies, factornode, options) end end @@ -105,15 +115,16 @@ function initialize_cluster!( # For the clusters of length `1` there is no need to create a new `MarginalObservable` object # We can simply reuse it from the variable connected to the factor node. Potentially it saves a bit of memory stream_of_cluster_marginals = if isone(length(localfactorization)) - getmarginal( - getvariable(getinterface(factornode, first(localfactorization))), - IncludeAll(), + get_stream_of_marginals( + getvariable(getinterface(factornode, first(localfactorization))) ) else # For the clusters of length `>1` we need to create the new strean, but it will be assigned later MarginalObservable() end - setmarginal!(clusters, index, stream_of_cluster_marginals) + set_node_local_marginal_stream!( + clusters, index, stream_of_cluster_marginals + ) end function activate_cluster!( @@ -124,19 +135,20 @@ function activate_cluster!( options, ) localfactorization = getfactorization(clusters, index) + stream_postprocessors = getpostprocessor(options) if !isone(length(localfactorization)) # For the clusters which length is not equal to one we should collect the dependencies # and call the `MarginalMapping` to compute the result. The `MarginalObservable` should have # been initialized in the `initialize_cluster!` before - marginal = getmarginal(clusters, index) + marginal = get_node_local_marginals(clusters)[index] clusterinterfaces = map( i -> getinterface(factornode, i), localfactorization ) message_dependencies = tuple(clusterinterfaces...) - marginal_dependencies = tuple(TupleTools.deleteat(getmarginals(clusters), index)...) + marginal_dependencies = tuple(TupleTools.deleteat(get_node_local_marginals(clusters), index)...) messagestag, messages = collect_latest_messages( dependencies, factornode, message_dependencies @@ -146,12 +158,13 @@ function activate_cluster!( ) fform = functionalform(factornode) - vtag = tag(getmarginal(clusters, index)) + vtag = tag(get_node_local_marginals(clusters)[index]) meta = collect_meta(fform, getmetadata(options)) mapping = MarginalMapping(fform, vtag, messagestag, marginalstag, meta, node_if_required(fform, factornode)) marginalout = combineLatestUpdates((messages, marginals), PushNew(), Marginal, mapping, reset_vstatus) + marginalout = postprocess_stream_of_marginals(stream_postprocessors, marginalout) - connect!(getmarginal(marginal), marginalout) + set_stream_of_marginals!(marginal, marginalout) end end diff --git a/src/nodes/dependencies.jl b/src/nodes/dependencies.jl index 0da1f77aa..e45435c26 100644 --- a/src/nodes/dependencies.jl +++ b/src/nodes/dependencies.jl @@ -4,10 +4,10 @@ export DefaultFunctionalDependencies, RequireEverythingFunctionalDependencies collect_latest_messages(dependencies, factornode, collection) = __collect_latest_updates( - messagein, collection + get_stream_of_inbound_messages, collection ) collect_latest_marginals(dependencies, factornode, collection) = __collect_latest_updates( - getmarginal, collection + get_stream_of_marginals, collection ) function __collect_latest_updates(f::F, collection) where {F} @@ -25,15 +25,22 @@ function __collect_latest_updates(f::F, collection::Tuple) where {F} end end +""" + ReactiveMP.FunctionalDependencies + +Abstract supertype for policies that determine which messages and marginals are required to compute each outbound message at a factor node. A concrete subtype is passed as `options.dependencies` in [`ReactiveMP.FactorNodeActivationOptions`](@ref) and consulted during [`ReactiveMP.activate!(::FactorNode, ::FactorNodeActivationOptions)`](@ref). + +See also: [`ReactiveMP.DefaultFunctionalDependencies`](@ref), [`ReactiveMP.RequireMessageFunctionalDependencies`](@ref), [`ReactiveMP.RequireMarginalFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref) +""" abstract type FunctionalDependencies end function activate!(dependencies::FunctionalDependencies, factornode, options) - scheduler = getscheduler(options) - addons = getaddons(options) - rulefallback = getrulefallback(options) - fform = functionalform(factornode) - meta = collect_meta(fform, getmetadata(options)) - pipeline = collect_pipeline(fform, getpipeline(options)) + annotations = getannotations(options) + rulefallback = getrulefallback(options) + callbacks = getcallbacks(options) + fform = functionalform(factornode) + meta = collect_meta(fform, getmetadata(options)) + stream_postprocessor = getpostprocessor(options) foreach(enumerate(getinterfaces(factornode))) do (iindex, interface) if israndom(interface) || isdata(interface) @@ -50,7 +57,9 @@ function activate!(dependencies::FunctionalDependencies, factornode, options) vtag = tag(interface) vconstraint = Marginalisation() - vmessageout = combineLatest((messages, marginals), PushNew()) + stream_of_outbound_messages = combineLatest( + (messages, marginals), PushNew() + ) mapping = let messagemap = MessageMapping( @@ -60,22 +69,24 @@ function activate!(dependencies::FunctionalDependencies, factornode, options) messagestag, marginalstag, meta, - addons, + annotations, node_if_required(fform, factornode), rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline, factornode, vtag, vmessageout + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessor, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!( + interface, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(interface), vmessageout) end end end @@ -100,6 +111,13 @@ In order to compute a message out of some interface, this strategy requires mess """ struct DefaultFunctionalDependencies <: FunctionalDependencies end +""" + ReactiveMP.collect_functional_dependencies(fform, dependencies) + +Returns the [`ReactiveMP.FunctionalDependencies`](@ref) instance to use for a factor node with functional form `fform`. +If `dependencies` is `nothing`, falls back to `default_functional_dependencies(fform)`, which returns [`ReactiveMP.DefaultFunctionalDependencies`](@ref) for most nodes. +Otherwise returns `dependencies` unchanged, allowing callers to override the policy per node. +""" function collect_functional_dependencies end collect_functional_dependencies(fform::F, ::Nothing) where {F} = default_functional_dependencies( @@ -125,7 +143,9 @@ function functional_dependencies( ) # For the marginal dependencies we need to skip the current cluster - marginal_dependencies = skipindex(getmarginals(clusters), cindex) + marginal_dependencies = skipindex( + get_node_local_marginals(clusters), cindex + ) return message_dependencies, marginal_dependencies end @@ -181,7 +201,9 @@ function functional_dependencies( initialmessage = specification[name(interface)] # Set the initial message if its not `nothing` if !isnothing(initialmessage) - setmessage!(messagein(interface), initialmessage) + set_initial_message!( + get_stream_of_inbound_messages(interface), initialmessage + ) end # And return the cluster as is cluster @@ -193,7 +215,9 @@ function functional_dependencies( ) # For the marginal dependencies we need to skip the current cluster - marginal_dependencies = skipindex(getmarginals(clusters), cindex) + marginal_dependencies = skipindex( + get_node_local_marginals(clusters), cindex + ) return message_dependencies, marginal_dependencies end @@ -248,7 +272,7 @@ function functional_dependencies( ) # For the marginal dependencies we need to skip the current cluster - marginal_dependencies_default_clusters = skipindex(getmarginals(clusters), cindex) + marginal_dependencies_default_clusters = skipindex(get_node_local_marginals(clusters), cindex) marginal_dependencies_default_factorization = skipindex(getfactorization(clusters), cindex) marginal_dependencies = if name(interface) ∈ keys(specification) @@ -257,14 +281,12 @@ function functional_dependencies( extra_localmarginal = FactorNodeLocalMarginal(name(interface)) # Create a stream of marginals and connect it with the streams of marginals of the actual variable extra_stream = MarginalObservable() - connect!( - extra_stream, getmarginal(getvariable(interface), IncludeAll()) - ) - setmarginal!(extra_localmarginal, extra_stream) + connect!(extra_stream, get_stream_of_marginals(getvariable(interface))) + set_stream_of_marginals!(extra_localmarginal, extra_stream) initialmarginals = specification[name(interface)] if !isnothing(initialmarginals) - setmarginal!(extra_stream, initialmarginals) + set_initial_marginal!(extra_stream, initialmarginals) end insertafter = sum( @@ -287,7 +309,7 @@ end """ RequireEverythingFunctionalDependencies -This pipeline specifies that in order to compute a message of some edge update rules request everything that is available locally. +This strategy specifies that in order to compute a message of some edge update rules request everything that is available locally. This includes all inbound messages (including on the same edge) and marginals over all local edge-clusters (this may or may not include marginals on single edges, depends on the local factorisation constraint). See also: [`DefaultFunctionalDependencies`](@ref), [`RequireMessageFunctionalDependencies`](@ref), [`RequireMarginalFunctionalDependencies`](@ref) @@ -306,7 +328,7 @@ function functional_dependencies( message_dependencies = Iterators.map( inds -> map(i -> getinterface(factornode, i), inds), cluster ) - marginal_dependencies = getmarginals(clusters) + marginal_dependencies = get_node_local_marginals(clusters) return message_dependencies, marginal_dependencies end diff --git a/src/nodes/equality.jl b/src/nodes/equality.jl index 3f374144d..a5f2628eb 100644 --- a/src/nodes/equality.jl +++ b/src/nodes/equality.jl @@ -40,8 +40,8 @@ mutable struct EqualityNode EqualityNode() = new( lazy(Missing), lazy(Missing), - Message(missing, true, true, nothing), - Message(missing, true, true, nothing), + Message(missing, true, true), + Message(missing, true, true), ) end @@ -63,23 +63,23 @@ setcache!(::EqualityRightOutbound, node::EqualityNode, cache::Message) = node.ca EqualityChain """ struct EqualityChain{P, F} - length :: Int - nodes :: Vector{EqualityNode} - inputmsgs :: Vector{MessageObservable{AbstractMessage}} - cacheleft :: BitVector - cacheright :: BitVector - pipeline :: P - prod_fn :: F + length :: Int + nodes :: Vector{EqualityNode} + inputmsgs :: Vector{MessageObservable{AbstractMessage}} + cacheleft :: BitVector + cacheright :: BitVector + postprocessor :: P + prod_fn :: F function EqualityChain( inputmsgs::Vector{MessageObservable{AbstractMessage}}, - pipeline::P, + postprocessor::P, prod_fn::F, ) where {P, F} n = length(inputmsgs) nodes = map(_ -> EqualityNode(), 1:n) return new{P, F}( - n, nodes, inputmsgs, falses(n), falses(n), pipeline, prod_fn + n, nodes, inputmsgs, falses(n), falses(n), postprocessor, prod_fn ) end end @@ -88,7 +88,7 @@ Base.length(chain::EqualityChain) = chain.length prod(chain::EqualityChain, left, right) = chain.prod_fn((left, right)) -getpipeline(chain::EqualityChain) = chain.pipeline +getpostprocessor(chain::EqualityChain) = chain.postprocessor @propagate_inbounds getnode(chain::EqualityChain, node_index) = chain.nodes[node_index] @@ -113,7 +113,7 @@ __check_indices(::EqualityRightOutbound, chain::EqualityChain, node_index) = 1 < if __check_indices(type, chain, node_index) return getcache(type, getnode(chain, node_index)) else - return Message(missing, true, true, nothing) + return Message(missing, true, true) end end @@ -146,7 +146,7 @@ nextindex(::EqualityRightOutbound, node_index) = node_index - 1 return materialize!(type, chain, node_index) end else - return Message(missing, true, true, nothing) + return Message(missing, true, true) end end @@ -205,7 +205,7 @@ Base.map(::Type{Message}, mapping::ChainOutboundMapping) = Rocket.MapOperator{ function initialize!(chain::EqualityChain, outputmsgs::AbstractVector) n = length(chain) - pipeline = getpipeline(chain) + postprocessor = getpostprocessor(chain) Left = EqualityLeftOutbound() Right = EqualityRightOutbound() @@ -219,8 +219,18 @@ function initialize!(chain::EqualityChain, outputmsgs::AbstractVector) tap(ChainInvalidationCallback(index, chain)) |> share_recent() - left = combineLatestUpdates((getoutbound(Left, chain, nextindex(Left, index)), input), PushNew()) |> pipeline |> map_to(missing) |> share_recent() - right = combineLatestUpdates((getoutbound(Right, chain, nextindex(Right, index)), input), PushNew()) |> pipeline |> map_to(missing) |> share_recent() + left = combineLatestUpdates( + (getoutbound(Left, chain, nextindex(Left, index)), input), PushNew() + ) + left = postprocess_stream_of_outbound_messages(postprocessor, left) + left = left |> map_to(missing) |> share_recent() + + right = combineLatestUpdates( + (getoutbound(Right, chain, nextindex(Right, index)), input), + PushNew(), + ) + right = postprocess_stream_of_outbound_messages(postprocessor, right) + right = right |> map_to(missing) |> share_recent() setoutbound!(Left, node, left) setoutbound!(Right, node, right) diff --git a/src/nodes/interfaces.jl b/src/nodes/interfaces.jl index 5e8480049..d4d95f6d2 100644 --- a/src/nodes/interfaces.jl +++ b/src/nodes/interfaces.jl @@ -1,7 +1,13 @@ """ - NodeInterface + ReactiveMP.NodeInterface -`NodeInterface` object represents a single node-variable connection. +Represents a single directed connection between a factor node and an [`ReactiveMP.AbstractVariable`](@ref). + +Each interface owns one [`ReactiveMP.MessageObservable`](@ref) (`m_out`) — the *outbound* message stream from this node toward the connected variable. The constructor immediately calls [`ReactiveMP.create_new_stream_of_inbound_messages!`](@ref) on the variable, which allocates a per-connection slot in the variable's `input_messages` and returns the same observable together with its index. This means `m_out` for the interface is the inbound message stream from the variable's perspective. + +After graph construction the streams are unconnected (lazy). [`ReactiveMP.activate!`](@ref) wires `m_out` to the result of the message update rule via [`ReactiveMP.set_stream_of_outbound_messages!`](@ref). + +See also: [`ReactiveMP.IndexedNodeInterface`](@ref), [`ReactiveMP.get_stream_of_outbound_messages`](@ref), [`ReactiveMP.get_stream_of_inbound_messages`](@ref) """ struct NodeInterface name::Symbol @@ -10,8 +16,8 @@ struct NodeInterface message_index::Int function NodeInterface(name::Symbol, variable::AbstractVariable) - # `messagein` for variable is `m_out` for the interface - m_out, message_index = create_messagein!(variable) + # `inbound message` for variable is `m_out` for the interface + m_out, message_index = create_new_stream_of_inbound_messages!(variable) return new(name, m_out, variable, message_index) end @@ -45,18 +51,28 @@ The major difference between tag and name is that it is possible to dispath on i tag(interface::NodeInterface) = Val{name(interface)}() """ - messageout(interface) + get_stream_of_outbound_messages(interface) Returns an outbound messages stream from the given interface. """ -messageout(interface::NodeInterface) = interface.m_out +get_stream_of_outbound_messages(interface::NodeInterface) = interface.m_out + +""" + ReactiveMP.set_stream_of_outbound_messages!(interface, stream) + +Connects `stream` to the outbound message observable of `interface`. +See also [`ReactiveMP.get_stream_of_outbound_messages`](@ref), [`ReactiveMP.get_stream_of_inbound_messages`](@ref). +""" +set_stream_of_outbound_messages!(interface::NodeInterface, stream) = connect!( + get_stream_of_outbound_messages(interface), stream +) """ - messagein(interface) + get_stream_of_inbound_messages(interface) Returns an inbound messages stream from the given interface. """ -messagein(interface::NodeInterface) = messageout( +get_stream_of_inbound_messages(interface::NodeInterface) = get_stream_of_outbound_messages( interface.variable, interface.message_index ) @@ -68,10 +84,11 @@ Returns a variable connected to the given interface. getvariable(interface::NodeInterface) = interface.variable """ - IndexedNodeInterface + ReactiveMP.IndexedNodeInterface + +A thin wrapper around [`ReactiveMP.NodeInterface`](@ref) that adds a positional `index`, used for nodes with a variable-length list of same-named edges (e.g. the `means` or `precisions` of a Gaussian Mixture node). All stream and variable accessors delegate to the wrapped interface. -`IndexedNodeInterface` object represents a repetative node-variable connection. -Used in cases when a node may connect to a different number of random variables with the same name, e.g. means and precisions of a Gaussian Mixture node. +See also: [`ReactiveMP.NodeInterface`](@ref), [`ReactiveMP.ManyOf`](@ref) """ struct IndexedNodeInterface index :: Int @@ -87,8 +104,15 @@ index(interface::IndexedNodeInterface) = interface.index name(interface::IndexedNodeInterface) = name(interface.interface) tag(interface::IndexedNodeInterface) = (tag(interface.interface), index(interface)) -messageout(interface::IndexedNodeInterface) = messageout(interface.interface) -messagein(interface::IndexedNodeInterface) = messagein(interface.interface) +get_stream_of_outbound_messages(interface::IndexedNodeInterface) = get_stream_of_outbound_messages( + interface.interface +) +set_stream_of_outbound_messages!(interface::IndexedNodeInterface, stream) = set_stream_of_outbound_messages!( + interface.interface, stream +) +get_stream_of_inbound_messages(interface::IndexedNodeInterface) = get_stream_of_inbound_messages( + interface.interface +) getvariable(interface::IndexedNodeInterface) = getvariable(interface.interface) israndom(interface::IndexedNodeInterface) = israndom(interface.interface) @@ -161,6 +185,8 @@ function combineLatestMessagesInUpdates( indexed::NTuple{N, <:IndexedNodeInterface} ) where {N} return ManyOfObservable( - combineLatestUpdates(map((in) -> messagein(in), indexed), PushNew()) + combineLatestUpdates( + map((in) -> get_stream_of_inbound_messages(in), indexed), PushNew() + ), ) end diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 8346cbd23..c9fdb6ff7 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -9,8 +9,6 @@ using Rocket using TupleTools using MacroTools -import Rocket: getscheduler - import Base: show, +, push!, iterate, IteratorSize, IteratorEltype, eltype, length, size import Base: getindex, setindex!, firstindex, lastindex @@ -274,27 +272,61 @@ function prepare_interfaces_check_num_inputarguments( ) end -struct FactorNodeActivationOptions{M, D, P, A, S, R} +""" + ReactiveMP.FactorNodeActivationOptions + +Collects all configuration needed to activate a [`FactorNode`](@ref). Passed to [`ReactiveMP.activate!(::FactorNode, ::FactorNodeActivationOptions)`](@ref). + +Fields: +- `metadata` — node-specific metadata forwarded to message update rules (see [`ReactiveMP.collect_meta`](@ref)) +- `dependencies` — a [`ReactiveMP.FunctionalDependencies`](@ref) policy that determines which messages and marginals each outbound message computation depends on (default: [`ReactiveMP.DefaultFunctionalDependencies`](@ref)) +- `postprocessor` — optional stream postprocessor applied to every created stream (see [`ReactiveMP.AbstractStreamPostprocessor`](@ref)) +- `annotations` — optional annotation processors (see [`ReactiveMP.AbstractAnnotations`](@ref)) +- `rulefallback` — optional fallback called when no `@rule` method matches +- `callbacks` — optional callbacks invoked at key points in the message computation (see [`ReactiveMP.invoke_callback`](@ref)) +""" +struct FactorNodeActivationOptions{M, D, P, A, R, E} metadata::M dependencies::D - pipeline::P - addons::A - scheduler::S + postprocessor::P + annotations::A rulefallback::R + callbacks::E end getmetadata(options::FactorNodeActivationOptions) = options.metadata getdependecies(options::FactorNodeActivationOptions) = options.dependencies -getpipeline(options::FactorNodeActivationOptions) = options.pipeline -getaddons(options::FactorNodeActivationOptions) = options.addons -getscheduler(options::FactorNodeActivationOptions) = options.scheduler +getpostprocessor(options::FactorNodeActivationOptions) = options.postprocessor +getannotations(options::FactorNodeActivationOptions) = options.annotations getrulefallback(options::FactorNodeActivationOptions) = options.rulefallback +getcallbacks(options::FactorNodeActivationOptions) = options.callbacks # Users can override the dependencies if they want to collect_functional_dependencies(fform::F, options::FactorNodeActivationOptions) where {F} = collect_functional_dependencies( fform, getdependecies(options) ) +""" + ReactiveMP.activate!(factornode::FactorNode, options::FactorNodeActivationOptions) + +Wires all reactive message-passing streams of a [`FactorNode`](@ref) into the factor graph. + +Activation proceeds in three phases: + +1. **Collect functional dependencies** — calls [`ReactiveMP.collect_functional_dependencies`](@ref) to determine, for each interface, which inbound messages and local marginals are needed to compute the outbound message. The policy is controlled by `options.dependencies` (default: [`ReactiveMP.DefaultFunctionalDependencies`](@ref)). + +2. **Initialize clusters** — sets up a [`ReactiveMP.FactorNodeLocalMarginal`](@ref) stream for each cluster in the factorization. For mean-field each cluster contains one variable and its local marginal is shared directly with the variable's marginal stream. For structured factorizations a joint marginal stream is created from the cluster's message dependencies. + +3. **Wire outbound message streams** — for every interface connected to a [`ReactiveMP.RandomVariable`](@ref) or [`ReactiveMP.DataVariable`](@ref): + - combines the required inbound message and marginal streams with `combineLatest` + - wraps the rule application in a [`ReactiveMP.MessageMapping`](@ref), producing a [`ReactiveMP.DeferredMessage`](@ref) on each upstream update + - applies any `options.postprocessor` transformations + - connects the result to the interface's [`ReactiveMP.MessageObservable`](@ref) via [`ReactiveMP.set_stream_of_outbound_messages!`](@ref) + +Interfaces connected to [`ReactiveMP.ConstVariable`](@ref) are skipped: their message is fixed at graph construction. + +See also: [`ReactiveMP.FactorNodeActivationOptions`](@ref), [`ReactiveMP.activate!(::RandomVariable, ::RandomVariableActivationOptions)`](@ref) +""" function activate!(factornode::FactorNode, options::FactorNodeActivationOptions) dependencies = collect_functional_dependencies( functionalform(factornode), options diff --git a/src/nodes/predefined/bifm_helper.jl b/src/nodes/predefined/bifm_helper.jl index 45e1b75aa..3c99ce8ed 100644 --- a/src/nodes/predefined/bifm_helper.jl +++ b/src/nodes/predefined/bifm_helper.jl @@ -31,7 +31,7 @@ function functional_dependencies( cindex = clusterindex(clusters, iindex) nodeinterfaces = getinterfaces(factornode) - nodelocalmarginals = getmarginals(clusters) + nodelocalmarginals = get_node_local_marginals(clusters) # output if iindex === 2 diff --git a/src/nodes/predefined/delta/delta.jl b/src/nodes/predefined/delta/delta.jl index b664bcca9..fa0eaca64 100644 --- a/src/nodes/predefined/delta/delta.jl +++ b/src/nodes/predefined/delta/delta.jl @@ -88,7 +88,7 @@ function rule( qnames, marginals, meta::DeltaMeta, - addons::Any, + annotations::Any, node::DeltaFnNode, ) where {F <: Function} return rule( @@ -100,7 +100,7 @@ function rule( qnames, marginals, meta, - addons, + annotations, node, ) end @@ -241,7 +241,7 @@ end # For datavar we get the latest value from the data stream __unpack_latest_static(_, constvar::ConstVariable) = getconst(constvar) __unpack_latest_static(_, datavar::DataVariable) = BayesBase.getpointmass( - getdata(Rocket.getrecent(messageout(datavar, 1))) + getdata(Rocket.getrecent(get_stream_of_outbound_messages(datavar, 1))) ) # By default all `meta` objects fallback to the `DeltaFnDefaultRuleLayout` @@ -268,9 +268,7 @@ end function activate!(factornode::DeltaFnNode, options) meta = collect_meta(functionalform(factornode), getmetadata(options)) - pipeline = collect_pipeline( - functionalform(factornode), getpipeline(options) - ) + stream_postprocessor = getpostprocessor(options) if !isnothing(getinverse(meta)) && !isempty(factornode.statics) error( @@ -284,7 +282,7 @@ function activate!(factornode::DeltaFnNode, options) factornode, deltafn_rule_layout(factornode, meta), meta, - pipeline, + stream_postprocessor, options, ) end @@ -293,7 +291,7 @@ function activate!( factornode::DeltaFnNode, layout::AbstractDeltaNodeDependenciesLayout, meta, - pipeline, + stream_postprocessors, options, ) foreach(getinterfaces(factornode)) do interface @@ -302,9 +300,9 @@ function activate!( ) end - scheduler = getscheduler(options) - addons = getaddons(options) + annotations = getannotations(options) rulefallback = getrulefallback(options) + callbacks = getcallbacks(options) # First we declare local marginal for `out` edge deltafn_apply_layout( @@ -312,10 +310,10 @@ function activate!( Val(:q_out), factornode, meta, - pipeline, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) # Second we declare how to compute a joint marginal over all inbound edges @@ -324,10 +322,10 @@ function activate!( Val(:q_ins), factornode, meta, - pipeline, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) # Second we declare message passing logic for out interface @@ -336,10 +334,10 @@ function activate!( Val(:m_out), factornode, meta, - pipeline, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) # At last we declare message passing logic for input interfaces @@ -348,10 +346,10 @@ function activate!( Val(:m_in), factornode, meta, - pipeline, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -361,17 +359,20 @@ function score( ::Deterministic, node::DeltaFnNode, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} # TODO (make a function for `node.localmarginals.marginals[2]`) - qinsmarginal = apply_skip_filter( - getmarginal(node.localmarginals.marginals[2]), skip_strategy - ) + qinsmarginal = + get_stream_of_marginals(node.localmarginals.marginals[2]) |> + skip_initial() - stream = qinsmarginal |> schedule_on(scheduler) mapping = (marginal) -> convert(T, -score(DifferentialEntropy(), marginal)) - return stream |> map(T, mapping) + stream_of_scores = qinsmarginal |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + + return stream_of_scores end diff --git a/src/nodes/predefined/delta/layouts/cvi.jl b/src/nodes/predefined/delta/layouts/cvi.jl index 36c3f5525..b5996e424 100644 --- a/src/nodes/predefined/delta/layouts/cvi.jl +++ b/src/nodes/predefined/delta/layouts/cvi.jl @@ -21,20 +21,20 @@ function deltafn_apply_layout( ::Val{:q_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_out), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -44,20 +44,20 @@ function deltafn_apply_layout( ::Val{:q_ins}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_ins), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -67,10 +67,10 @@ function deltafn_apply_layout( ::Val{:m_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) let interface = factornode.out @@ -80,13 +80,13 @@ function deltafn_apply_layout( # CVI requires `q_ins` marginal_names = Val{(:ins,)}() - marginals_observable = combineLatestUpdates((getmarginal(factornode.localmarginals.marginals[2]),), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(factornode.localmarginals.marginals[2]),), PushNew()) fform = functionalform(factornode) vtag = tag(interface) vconstraint = Marginalisation() - vmessageout = combineLatest( + stream_of_outbound_messages = combineLatest( (msgs_observable, marginals_observable), PushNew() ) @@ -98,23 +98,25 @@ function deltafn_apply_layout( msgs_names, marginal_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = with_statics(factornode, vmessageout) - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline_stages, factornode, vtag, vmessageout + stream_of_outbound_messages = with_statics( + factornode, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(interface), vmessageout) + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessors, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!(interface, stream_of_outbound_messages) end end @@ -124,19 +126,19 @@ function deltafn_apply_layout( ::Val{:m_in}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:m_in), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end diff --git a/src/nodes/predefined/delta/layouts/default.jl b/src/nodes/predefined/delta/layouts/default.jl index 315a8e3ee..c45032cf2 100644 --- a/src/nodes/predefined/delta/layouts/default.jl +++ b/src/nodes/predefined/delta/layouts/default.jl @@ -29,7 +29,7 @@ function with_statics( # We wait for the statics to be available, but ignore their actual values # They are being injected indirectly with the `fix` function upon node creation statics = map( - static -> messageout(static, 1), + static -> get_stream_of_outbound_messages(static, 1), FixedArguments.value.(factornode.statics), ) return combineLatest((stream, combineLatest(statics, PushNew()))) |> @@ -49,15 +49,17 @@ function deltafn_apply_layout( ::Val{:q_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) let out = factornode.out, localmarginal = factornode.localmarginals.marginals[1] # We simply subscribe on the marginal of the connected variable on `out` edge - setmarginal!(localmarginal, getmarginal(getvariable(out), IncludeAll())) + set_stream_of_marginals!( + localmarginal, get_stream_of_marginals(getvariable(out)) + ) end end @@ -67,21 +69,21 @@ function deltafn_apply_layout( ::Val{:q_ins}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) let out = factornode.out, ins = factornode.ins, localmarginal = factornode.localmarginals.marginals[2] cmarginal = MarginalObservable() - setmarginal!(localmarginal, cmarginal) + set_stream_of_marginals!(localmarginal, cmarginal) # By default to compute `q_ins` we need messages both from `:out` and `:ins` msgs_names = Val{(:out, :ins)}() - msgs_observable = combineLatestUpdates((messagein(out), combineLatestMessagesInUpdates(ins)), PushNew()) + msgs_observable = combineLatestUpdates((get_stream_of_inbound_messages(out), combineLatestMessagesInUpdates(ins)), PushNew()) # By default, we should not need any local marginals marginal_names = nothing @@ -92,6 +94,7 @@ function deltafn_apply_layout( mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, factornode) marginalout = combineLatestUpdates((with_statics(factornode, msgs_observable), with_statics(factornode, marginals_observable)), PushNew(), Marginal, mapping, reset_vstatus) + marginalout = postprocess_stream_of_marginals(stream_postprocessors, marginalout) connect!(cmarginal, marginalout) end @@ -103,10 +106,10 @@ function deltafn_apply_layout( ::Val{:m_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) let out = factornode.out, ins = factornode.ins @@ -122,7 +125,7 @@ function deltafn_apply_layout( vtag = Val{:out}() vconstraint = Marginalisation() - vmessageout = combineLatest( + stream_of_outbound_messages = combineLatest( (msgs_observable, marginals_observable), PushNew() ) @@ -134,23 +137,25 @@ function deltafn_apply_layout( msgs_names, marginal_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = with_statics(factornode, vmessageout) - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline_stages, factornode, vtag, vmessageout + stream_of_outbound_messages = with_statics( + factornode, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(out), vmessageout) + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessors, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!(out, stream_of_outbound_messages) end end @@ -160,25 +165,25 @@ function deltafn_apply_layout( ::Val{:m_in}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) # For each outbound message from `in_k` edge we need an inbound message on this edge and a joint marginal over `:ins` edges foreach(factornode.ins) do interface msgs_names = Val{(:in,)}() - msgs_observable = combineLatestUpdates((messagein(interface),), PushNew()) + msgs_observable = combineLatestUpdates((get_stream_of_inbound_messages(interface),), PushNew()) marginal_names = Val{(:ins,)}() - marginals_observable = combineLatestUpdates((getmarginal(factornode.localmarginals.marginals[2]),), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(factornode.localmarginals.marginals[2]),), PushNew()) fform = functionalform(factornode) vtag = tag(interface) vconstraint = Marginalisation() - vmessageout = combineLatest( + stream_of_outbound_messages = combineLatest( (msgs_observable, marginals_observable), PushNew() ) @@ -190,23 +195,25 @@ function deltafn_apply_layout( msgs_names, marginal_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = with_statics(factornode, vmessageout) - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline_stages, factornode, vtag, vmessageout + stream_of_outbound_messages = with_statics( + factornode, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(interface), vmessageout) + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessors, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!(interface, stream_of_outbound_messages) end end @@ -232,20 +239,20 @@ function deltafn_apply_layout( ::Val{:q_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_out), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -254,20 +261,20 @@ function deltafn_apply_layout( ::Val{:q_ins}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:q_ins), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -276,20 +283,20 @@ function deltafn_apply_layout( ::Val{:m_out}, factornode::DeltaFnNode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) return deltafn_apply_layout( DeltaFnDefaultRuleLayout(), Val(:m_out), factornode, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) end @@ -299,10 +306,10 @@ function deltafn_apply_layout( ::Val{:m_in}, factornode::DeltaFnNode{F}, meta, - pipeline_stages, - scheduler, - addons, + stream_postprocessors, + annotations, rulefallback, + callbacks, ) where {F} N = length(factornode.ins) @@ -312,7 +319,7 @@ function deltafn_apply_layout( # If we have only one `interface` we replace it with nothing # In other cases we remove the current index from the list of interfaces msgs_ins_stream = if N === 1 # `N` should be known at compile-time here so this `if` branch must be compiled out - of(Message(nothing, true, true, nothing)) + of(Message(nothing, true, true)) else combineLatestMessagesInUpdates( TupleTools.deleteat(factornode.ins, index) @@ -320,7 +327,7 @@ function deltafn_apply_layout( end msgs_names = Val{(:out, :ins)}() - msgs_observable = combineLatestUpdates((messagein(factornode.out), msgs_ins_stream), PushNew()) + msgs_observable = combineLatestUpdates((get_stream_of_inbound_messages(factornode.out), msgs_ins_stream), PushNew()) marginal_names = nothing marginals_observable = of(nothing) @@ -329,7 +336,7 @@ function deltafn_apply_layout( vtag = tag(interface) vconstraint = Marginalisation() - vmessageout = combineLatest( + stream_of_outbound_messages = combineLatest( (msgs_observable, marginals_observable), PushNew() ) @@ -341,22 +348,24 @@ function deltafn_apply_layout( msgs_names, marginal_names, meta, - addons, + annotations, factornode, rulefallback, + callbacks, ) (dependencies) -> DeferredMessage( dependencies[1], dependencies[2], messagemap ) end - vmessageout = with_statics(factornode, vmessageout) - vmessageout = vmessageout |> map(AbstractMessage, mapping) - vmessageout = apply_pipeline_stage( - pipeline_stages, factornode, vtag, vmessageout + stream_of_outbound_messages = with_statics( + factornode, stream_of_outbound_messages ) - vmessageout = vmessageout |> schedule_on(scheduler) - - connect!(messageout(interface), vmessageout) + stream_of_outbound_messages = + stream_of_outbound_messages |> map(AbstractMessage, mapping) + stream_of_outbound_messages = postprocess_stream_of_outbound_messages( + stream_postprocessors, stream_of_outbound_messages + ) + set_stream_of_outbound_messages!(interface, stream_of_outbound_messages) end end diff --git a/src/nodes/predefined/distribution/distribution.jl b/src/nodes/predefined/distribution/distribution.jl index 594ad6834..572adb35f 100644 --- a/src/nodes/predefined/distribution/distribution.jl +++ b/src/nodes/predefined/distribution/distribution.jl @@ -51,8 +51,12 @@ function activate!( factornode, options, ) - vmessageout = of(Message(factornode.distribution, true, false, nothing)) - connect!(messageout(getinterface(factornode, 1)), vmessageout) + stream_of_outbound_messages = of( + Message(factornode.distribution, true, false) + ) + set_stream_of_outbound_messages!( + getinterface(factornode, 1), stream_of_outbound_messages + ) return nothing end @@ -62,19 +66,21 @@ function score( ::FactorBoundFreeEnergy, node::StandaloneDistributionNode, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} - fnstream = let skip_strategy = skip_strategy, scheduler = scheduler - (localmarginal) -> - apply_skip_filter(getmarginal(localmarginal), skip_strategy) |> - schedule_on(scheduler) - end # `FactorBoundFreeEnergy` here is simply equal to `kldivergence` between the marginal and the outbound message - stream = fnstream(first(getmarginals(getlocalclusters(node)))) - return stream |> map( - T, - (marginal) -> - convert(T, score(KLDivergence(), marginal, node.distribution)), + stream_of_scores = + get_stream_of_marginals( + first(get_node_local_marginals(getlocalclusters(node))) + ) |> skip_initial() + stream_of_scores = + stream_of_scores |> map( + T, + (marginal) -> + convert(T, score(KLDivergence(), marginal, node.distribution)), + ) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores ) + return stream_of_scores end diff --git a/src/nodes/predefined/gamma_mixture.jl b/src/nodes/predefined/gamma_mixture.jl index 6cd16885e..0c2bd0fe7 100644 --- a/src/nodes/predefined/gamma_mixture.jl +++ b/src/nodes/predefined/gamma_mixture.jl @@ -150,18 +150,17 @@ function collect_latest_marginals( marginals_observable = combineLatest( ( - getmarginal(getvariable(varinterface), IncludeAll()), + get_stream_of_marginals(getvariable(varinterface)), combineLatest( map( - (rate) -> getmarginal(getvariable(rate), IncludeAll()), + (rate) -> get_stream_of_marginals(getvariable(rate)), reverse(bsinterfaces), ), PushNew(), ), combineLatest( map( - (shape) -> - getmarginal(getvariable(shape), IncludeAll()), + (shape) -> get_stream_of_marginals(getvariable(shape)), reverse(asinterfaces), ), PushNew(), @@ -169,16 +168,16 @@ function collect_latest_marginals( ), PushNew(), ) |> map_to(( - getmarginal(getvariable(varinterface), IncludeAll()), + get_stream_of_marginals(getvariable(varinterface)), ManyOf( map( - (shape) -> getmarginal(getvariable(shape), IncludeAll()), + (shape) -> get_stream_of_marginals(getvariable(shape)), asinterfaces, ), ), ManyOf( map( - (rate) -> getmarginal(getvariable(rate), IncludeAll()), + (rate) -> get_stream_of_marginals(getvariable(rate)), bsinterfaces, ), ), @@ -199,7 +198,7 @@ function collect_latest_marginals( varinterface = marginal_dependencies[3] marginal_names = Val{(name(outinterface), name(switchinterface), name(varinterface))}() - marginals_observable = combineLatestUpdates((getmarginal(getvariable(outinterface), IncludeAll()), getmarginal(getvariable(switchinterface), IncludeAll()), getmarginal(getvariable(varinterface), IncludeAll())), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(getvariable(outinterface)), get_stream_of_marginals(getvariable(switchinterface)), get_stream_of_marginals(getvariable(varinterface))), PushNew()) return marginal_names, marginals_observable end @@ -218,10 +217,7 @@ end AverageEnergy(), GammaShapeRate, Val{(:out, :α, :β)}(), - map( - (q) -> Marginal(q, false, false, nothing), - (q_out, q_a[i], q_b[i]), - ), + map((q) -> Marginal(q, false, false), (q_out, q_a[i], q_b[i])), nothing, ) end @@ -233,21 +229,18 @@ function score( ::Stochastic, node::GammaMixtureNode{N}, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal, N} stream = combineLatest( ( - getmarginal(getvariable(node.out), skip_strategy) |> - schedule_on(scheduler), - getmarginal(getvariable(node.switch), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(node.out)) |> skip_initial(), + get_stream_of_marginals(getvariable(node.switch)) |> skip_initial(), ManyOfObservable( combineLatest( map( (as) -> - getmarginal(getvariable(as), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(as)) |> + skip_initial(), node.as, ), PushNew(), @@ -257,8 +250,8 @@ function score( combineLatest( map( (bs) -> - getmarginal(getvariable(bs), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(bs)) |> + skip_initial(), node.bs, ), PushNew(), @@ -292,7 +285,11 @@ function score( end end - return stream |> map(T, mapping) + stream_of_scores = stream |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + return stream_of_scores end ## Extra distribution for the Gamma Mixture diff --git a/src/nodes/predefined/mixture.jl b/src/nodes/predefined/mixture.jl index 67f864914..8d93987de 100644 --- a/src/nodes/predefined/mixture.jl +++ b/src/nodes/predefined/mixture.jl @@ -40,7 +40,7 @@ struct MixtureNode{N} <: AbstractFactorNode end ``` - Note: The `Mixture` node requires the `AddonLogScale` addon to be included in the addons. However, this addon is not available for most message update rules. RxInfer.jl, which uses `ReactiveMP.jl` under the hood, allows to pass addons in the [`infer`](https://reactivebayes.github.io/RxInfer.jl/stable/manuals/inference/overview/) function. Only for certain sum-product update rules these are included. For a detailed explanation on the `Mixture` node see the [Mixture node paper](https://www.mdpi.com/1099-4300/25/8/1138). + Note: The `Mixture` node requires the `LogScaleAnnotations` annotation processor to be enabled. RxInfer.jl, which uses `ReactiveMP.jl` under the hood, allows to pass annotation processors in the [`infer`](https://reactivebayes.github.io/RxInfer.jl/stable/manuals/inference/overview/) function. Only for certain sum-product update rules these are included. For a detailed explanation on the `Mixture` node see the [Mixture node paper](https://www.mdpi.com/1099-4300/25/8/1138). """ out :: NodeInterface switch :: NodeInterface @@ -168,7 +168,6 @@ function functional_dependencies( return message_dependencies, marginal_dependencies end -# create message observable for output or Mixture edge without pipeline constraints (the message towards the inputs are fine by default behaviour, i.e. they depend only on switch and output and no longer on all other inputs) function collect_latest_messages( ::MixtureNodeFunctionalDependencies, factornode::MixtureNode{N}, @@ -183,21 +182,28 @@ function collect_latest_messages( msgs_observable = combineLatest( ( - messagein(output_or_switch_interface), + get_stream_of_inbound_messages(output_or_switch_interface), combineLatest( - map((input) -> messagein(input), inputsinterfaces), + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), PushNew(), ), ), PushNew(), ) |> map_to(( - messagein(output_or_switch_interface), - ManyOf(map((input) -> messagein(input), inputsinterfaces)), + get_stream_of_inbound_messages(output_or_switch_interface), + ManyOf( + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), + ), )) return msgs_names, msgs_observable end -# create an observable that is used to compute the switch with pipeline constraints function collect_latest_messages( ::RequireMarginalFunctionalDependencies, factornode::MixtureNode{N}, @@ -210,21 +216,28 @@ function collect_latest_messages( msgs_observable = combineLatest( ( - messagein(switchinterface), + get_stream_of_inbound_messages(switchinterface), combineLatest( - map((input) -> messagein(input), inputsinterfaces), + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), PushNew(), ), ), PushNew(), ) |> map_to(( - messagein(switchinterface), - ManyOf(map((input) -> messagein(input), inputsinterfaces)), + get_stream_of_inbound_messages(switchinterface), + ManyOf( + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), + ), )) return msgs_names, msgs_observable end -# create an observable that is used to compute the output with pipeline constraints function collect_latest_messages( ::RequireMarginalFunctionalDependencies, ::MixtureNode{N}, @@ -235,13 +248,22 @@ function collect_latest_messages( msgs_names = Val{(name(inputsinterfaces[1]),)}() msgs_observable = combineLatest( - map((input) -> messagein(input), inputsinterfaces), PushNew() - ) |> - map_to((ManyOf(map((input) -> messagein(input), inputsinterfaces)),)) + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), + PushNew(), + ) |> map_to(( + ManyOf( + map( + (input) -> get_stream_of_inbound_messages(input), + inputsinterfaces, + ), + ), + )) return msgs_names, msgs_observable end -# create an observable that is used to compute the input with pipeline constraints function collect_latest_messages( ::RequireMarginalFunctionalDependencies, factornode::MixtureNode{N}, @@ -251,7 +273,7 @@ function collect_latest_messages( msgs_names = Val{(name(outputinterface),)}() msgs_observable = combineLatestUpdates( - (messagein(outputinterface),), PushNew() + (get_stream_of_inbound_messages(outputinterface),), PushNew() ) return msgs_names, msgs_observable end @@ -272,7 +294,7 @@ function collect_latest_marginals( switchinterface = marginals[1] marginal_names = Val{(name(switchinterface),)}() - marginals_observable = combineLatestUpdates((getmarginal(getvariable(switchinterface), IncludeAll()),), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(getvariable(switchinterface)),), PushNew()) return marginal_names, marginals_observable end diff --git a/src/nodes/predefined/normal_mixture.jl b/src/nodes/predefined/normal_mixture.jl index 70be66e65..5c5fd525a 100644 --- a/src/nodes/predefined/normal_mixture.jl +++ b/src/nodes/predefined/normal_mixture.jl @@ -156,17 +156,17 @@ function collect_latest_marginals( marginals_observable = combineLatest( ( - getmarginal(getvariable(varinterface), IncludeAll()), + get_stream_of_marginals(getvariable(varinterface)), combineLatest( map( - (prec) -> getmarginal(getvariable(prec), IncludeAll()), + (prec) -> get_stream_of_marginals(getvariable(prec)), reverse(precsinterfaces), ), PushNew(), ), combineLatest( map( - (mean) -> getmarginal(getvariable(mean), IncludeAll()), + (mean) -> get_stream_of_marginals(getvariable(mean)), reverse(meansinterfaces), ), PushNew(), @@ -174,16 +174,16 @@ function collect_latest_marginals( ), PushNew(), ) |> map_to(( - getmarginal(getvariable(varinterface), IncludeAll()), + get_stream_of_marginals(getvariable(varinterface)), ManyOf( map( - (mean) -> getmarginal(getvariable(mean), IncludeAll()), + (mean) -> get_stream_of_marginals(getvariable(mean)), meansinterfaces, ), ), ManyOf( map( - (prec) -> getmarginal(getvariable(prec), IncludeAll()), + (prec) -> get_stream_of_marginals(getvariable(prec)), precsinterfaces, ), ), @@ -204,7 +204,7 @@ function collect_latest_marginals( varinterface = marginal_dependencies[3] marginal_names = Val{(name(outinterface), name(switchinterface), name(varinterface))}() - marginals_observable = combineLatestUpdates((getmarginal(getvariable(outinterface), IncludeAll()), getmarginal(getvariable(switchinterface), IncludeAll()), getmarginal(getvariable(varinterface), IncludeAll())), PushNew()) + marginals_observable = combineLatestUpdates((get_stream_of_marginals(getvariable(outinterface)), get_stream_of_marginals(getvariable(switchinterface)), get_stream_of_marginals(getvariable(varinterface))), PushNew()) return marginal_names, marginals_observable end @@ -225,7 +225,7 @@ function avg_energy_nm(::Type{Univariate}, q_out, q_m, q_p, z_bar, i) AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}(), - map((q) -> Marginal(q, false, false, nothing), (q_out, q_m[i], q_p[i])), + map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing, ) end @@ -235,7 +235,7 @@ function avg_energy_nm(::Type{Multivariate}, q_out, q_m, q_p, z_bar, i) AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}(), - map((q) -> Marginal(q, false, false, nothing), (q_out, q_m[i], q_p[i])), + map((q) -> Marginal(q, false, false), (q_out, q_m[i], q_p[i])), nothing, ) end @@ -246,21 +246,18 @@ function score( ::Stochastic, node::NormalMixtureNode{N}, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal, N} stream = combineLatest( ( - getmarginal(getvariable(node.out), skip_strategy) |> - schedule_on(scheduler), - getmarginal(getvariable(node.switch), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(node.out)) |> skip_initial(), + get_stream_of_marginals(getvariable(node.switch)) |> skip_initial(), ManyOfObservable( combineLatest( map( (mean) -> - getmarginal(getvariable(mean), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(mean)) |> + skip_initial(), node.means, ), PushNew(), @@ -270,8 +267,8 @@ function score( combineLatest( map( (prec) -> - getmarginal(getvariable(prec), skip_strategy) |> - schedule_on(scheduler), + get_stream_of_marginals(getvariable(prec)) |> + skip_initial(), node.precs, ), PushNew(), @@ -308,5 +305,10 @@ function score( end end - return stream |> map(T, mapping) + stream_of_scores = stream |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + + return stream_of_scores end diff --git a/src/pipeline/async.jl b/src/pipeline/async.jl deleted file mode 100644 index 1d39978fc..000000000 --- a/src/pipeline/async.jl +++ /dev/null @@ -1,13 +0,0 @@ -export AsyncPipelineStage - -import Rocket: async - -""" - AsyncPipelineStage <: AbstractPipelineStage - -Applies the `async()` operator from `Rocket.jl` library to the given pipeline -""" -struct AsyncPipelineStage <: AbstractPipelineStage end - -apply_pipeline_stage(::AsyncPipelineStage, factornode, tag, stream) = - stream |> async() diff --git a/src/pipeline/discontinue.jl b/src/pipeline/discontinue.jl deleted file mode 100644 index 1b18e133e..000000000 --- a/src/pipeline/discontinue.jl +++ /dev/null @@ -1,11 +0,0 @@ -export DiscontinuePipelineStage - -""" - DiscontinuePipelineStage <: AbstractPipelineStage - -Applies the `discontinue()` operator from `Rocket.jl` library to the given pipeline -""" -struct DiscontinuePipelineStage <: AbstractPipelineStage end - -apply_pipeline_stage(::DiscontinuePipelineStage, factornode, tag, stream) = - stream |> discontinue() diff --git a/src/pipeline/logger.jl b/src/pipeline/logger.jl deleted file mode 100644 index e9c412ee8..000000000 --- a/src/pipeline/logger.jl +++ /dev/null @@ -1,38 +0,0 @@ -export LoggerPipelineStage - -""" - LoggerPipelineStage <: AbstractPipelineStage - -Logs all updates from `stream` into `output` - -# Arguments -- `output`: (optional), an output stream used to print log statements -""" -struct LoggerPipelineStage{T} <: AbstractPipelineStage - output::T - prefix::String -end - -LoggerPipelineStage() = LoggerPipelineStage(Core.stdout, "Log") -LoggerPipelineStage(output::IO) = LoggerPipelineStage(output, "Log") -LoggerPipelineStage(prefix::String) = LoggerPipelineStage(Core.stdout, prefix) - -logger_pipeline_stage_println(logger::LoggerPipelineStage, something::Any) = logger_pipeline_stage_println( - logger, - logger.output, - logger_pipeline_stage_append_prefix(logger, something), -) - -logger_pipeline_stage_println(logger::LoggerPipelineStage, output::Core.CoreSTDOUT, something) = Core.println( - output, something -) -logger_pipeline_stage_println(logger::LoggerPipelineStage, output, something) = println( - output, something -) - -logger_pipeline_stage_append_prefix(logger::LoggerPipelineStage, something) = string( - "[", logger.prefix, "]: ", something -) - -apply_pipeline_stage(logger::LoggerPipelineStage, factornode, tag::Val{T}, stream) where {T} = stream |> tap((v) -> logger_pipeline_stage_println(logger, string("[", functionalform(factornode), "][", T, "]: ", v))) -apply_pipeline_stage(logger::LoggerPipelineStage, factornode, tag::Tuple{Val{T}, Int}, stream) where {T} = stream |> tap((v) -> logger_pipeline_stage_println(logger, string("[", functionalform(factornode), "][", T, ":", tag[2], "]: ", v))) diff --git a/src/pipeline/pipeline.jl b/src/pipeline/pipeline.jl deleted file mode 100644 index 88587519d..000000000 --- a/src/pipeline/pipeline.jl +++ /dev/null @@ -1,68 +0,0 @@ -export AbstractPipelineStage, - EmptyPipelineStage, CompositePipelineStage, apply_pipeline_stage - -import Base: + - -## Abstract Custom Pipeline Stage - -""" - AbstractPipelineStage - -An abstract type for all custom pipelines stages -""" -abstract type AbstractPipelineStage end - -""" - apply_pipeline_stage(stage, factornode, tag, stream) - -Applies a given pipeline stage to the `stream` argument given `factornode` and `tag` of an edge. -""" -function apply_pipeline_stage end - -## Default empty pipeline - -""" - EmptyPipelineStage <: AbstractPipelineStage - -Dummy empty pipeline stage that does not modify the original pipeline. -""" -struct EmptyPipelineStage <: AbstractPipelineStage end - -apply_pipeline_stage(::EmptyPipelineStage, factornode, tag, stream) = stream - -## Composite pipeline - -""" - CompositePipelineStage{T} <: AbstractPipelineStage - -Composite pipeline stage that consists of multiple inner pipeline stages -""" -struct CompositePipelineStage{T} <: AbstractPipelineStage - stages::T -end - -apply_pipeline_stage(composite::CompositePipelineStage, factornode, tag, stream) = reduce( - (stream, stage) -> apply_pipeline_stage(stage, factornode, tag, stream), - composite.stages; - init = stream, -) - -Base.:+(stage::AbstractPipelineStage) = stage - -Base.:+(left::EmptyPipelineStage, right::EmptyPipelineStage) = EmptyPipelineStage() -Base.:+(left::EmptyPipelineStage, right::AbstractPipelineStage) = right -Base.:+(left::AbstractPipelineStage, right::EmptyPipelineStage) = left -Base.:+(left::AbstractPipelineStage, right::AbstractPipelineStage) = CompositePipelineStage((left, right)) -Base.:+(left::AbstractPipelineStage, right::CompositePipelineStage) = CompositePipelineStage((left, right.stages...)) -Base.:+(left::CompositePipelineStage, right::AbstractPipelineStage) = CompositePipelineStage((left.stages..., right)) -Base.:+(left::CompositePipelineStage, right::CompositePipelineStage) = CompositePipelineStage((left.stages..., right.stages...)) - -""" - collect_pipeline(nodetype, pipeline) - -This function converts given pipeline to a correct internal pipeline representation for a factor given node. -""" -function collect_pipeline end - -collect_pipeline(any, ::Nothing) = EmptyPipelineStage() -collect_pipeline(any, something) = something diff --git a/src/pipeline/scheduled.jl b/src/pipeline/scheduled.jl deleted file mode 100644 index de96336a7..000000000 --- a/src/pipeline/scheduled.jl +++ /dev/null @@ -1,63 +0,0 @@ -export ScheduleOnPipelineStage, schedule_updates - -import Rocket: release! - -""" - ScheduleOnPipelineStage{S} <: AbstractPipelineStage - -Applies the `schedule_on()` operator from `Rocket.jl` library to the given pipeline with a provided `scheduler` - -# Arguments -- `scheduler`: scheduler to schedule updates on. Must be compatible with `Rocket.jl` library and `schedule_on()` operator. -""" -struct ScheduleOnPipelineStage{S} <: AbstractPipelineStage - scheduler::S -end - -apply_pipeline_stage(stage::ScheduleOnPipelineStage, factornode, tag, stream) = - stream |> schedule_on(stage.scheduler) - -Rocket.release!(stage::ScheduleOnPipelineStage) = release!(stage.scheduler) -Rocket.release!(stages::NTuple{N, <:ScheduleOnPipelineStage}) where {N} = foreach(release!, stages) -Rocket.release!(stages::AbstractArray{<:ScheduleOnPipelineStage}) = foreach(release!, stages) - -update!(stage::ScheduleOnPipelineStage) = release!(stage.scheduler) -update!(stages::NTuple{N, <:ScheduleOnPipelineStage}) where {N} = foreach(update!, stages) -update!(stages::AbstractArray{<:ScheduleOnPipelineStage}) = foreach(update!, stages) - -function _schedule_updates end - -__schedule_updates(var::AbstractVariable) = __schedule_updates((var,)) -__schedule_updates(vars::NTuple{N, <:AbstractVariable}) where {N} = __schedule_updates(ScheduleOnPipelineStage(PendingScheduler()), vars) -__schedule_updates(vars::AbstractArray{<:AbstractVariable}) = __schedule_updates(ScheduleOnPipelineStage(PendingScheduler()), vars) - -__schedule_updates(pipeline_stage::ScheduleOnPipelineStage, var::AbstractVariable) = __schedule_updates( - pipeline_stage, (var,) -) - -function __schedule_updates( - pipeline_stage::ScheduleOnPipelineStage, vars::NTuple{N, <:AbstractVariable} -) where {N} - foreach((v) -> add_pipeline_stage!(v, pipeline_stage), vars) - return pipeline_stage -end - -function __schedule_updates( - pipeline_stage::ScheduleOnPipelineStage, - vars::AbstractArray{<:AbstractVariable}, -) - foreach((v) -> add_pipeline_stage!(v, pipeline_stage), vars) - return pipeline_stage -end - -""" - schedule_updates(variables...; pipeline_stage = ScheduleOnPipelineStage(PendingScheduler())) - -Schedules posterior marginal updates for given variables using `stage`. By default creates `ScheduleOnPipelineStage` with `PendingScheduler()` from `Rocket.jl` library. -Returns a scheduler with `release!` method available to release all scheduled updates. -""" -function schedule_updates( - args...; pipeline_stage = ScheduleOnPipelineStage(PendingScheduler()) -) - return map((arg) -> __schedule_updates(pipeline_stage, arg), args) -end diff --git a/src/postprocessors.jl b/src/postprocessors.jl new file mode 100644 index 000000000..d6d7ad358 --- /dev/null +++ b/src/postprocessors.jl @@ -0,0 +1,157 @@ +""" + AbstractStreamPostprocessor + +Abstract supertype for **stream postprocessors** — composable transformations +applied to the reactive observables produced during graph activation. + +A stream postprocessor wraps a Rocket.jl observable and returns a new observable +of the same element type. The same postprocessor can be applied to three +different kinds of streams produced by the inference engine, each with its own +entry point: + +- [`ReactiveMP.postprocess_stream_of_outbound_messages`](@ref) — the stream of + outbound [`Message`](@ref)s leaving a factor node interface (or a leg of an + [`ReactiveMP.EqualityChain`](@ref)). +- [`ReactiveMP.postprocess_stream_of_marginals`](@ref) — the stream of + [`Marginal`](@ref)s emitted by a [`RandomVariable`](@ref) or by the local + cluster of a factor node. +- [`ReactiveMP.postprocess_stream_of_scores`](@ref) — the stream of free-energy + contributions. + +Stream postprocessors are attached to an inference run via +[`ReactiveMP.FactorNodeActivationOptions`](@ref) and +[`ReactiveMP.RandomVariableActivationOptions`](@ref). Multiple postprocessors +can be chained with [`ReactiveMP.CompositeStreamPostprocessor`](@ref). + +# Built-in implementations + +- [`ReactiveMP.ScheduleOnStreamPostprocessor`](@ref) — redirects the + computation onto a custom Rocket.jl scheduler (e.g. `PendingScheduler`, + `AsyncScheduler`). +- [`ReactiveMP.CompositeStreamPostprocessor`](@ref) — applies a sequence of + postprocessors in order. + +See also: [`ReactiveMP.postprocess_stream_of_outbound_messages`](@ref), +[`ReactiveMP.postprocess_stream_of_marginals`](@ref), +[`ReactiveMP.postprocess_stream_of_scores`](@ref). +""" +abstract type AbstractStreamPostprocessor end + +""" + postprocess_stream_of_outbound_messages(postprocessor, stream) + +Apply `postprocessor` to a stream of outbound [`Message`](@ref)s and return the +transformed stream. Called by [`ReactiveMP.activate!`](@ref) on every outbound +message stream produced by a factor node interface. + +The default fallback for `::Nothing` returns `stream` unchanged. Subtypes of +[`ReactiveMP.AbstractStreamPostprocessor`](@ref) may override this method to +e.g. redirect emissions to a Rocket.jl scheduler. +""" +function postprocess_stream_of_outbound_messages end + +""" + postprocess_stream_of_marginals(postprocessor, stream) + +Apply `postprocessor` to a stream of [`Marginal`](@ref)s and return the +transformed stream. Called by [`ReactiveMP.activate!`](@ref) on every marginal +stream produced for a [`RandomVariable`](@ref) or for a local cluster of a +factor node. + +The default fallback for `::Nothing` returns `stream` unchanged. Subtypes of +[`ReactiveMP.AbstractStreamPostprocessor`](@ref) may override this method. +""" +function postprocess_stream_of_marginals end + +""" + postprocess_stream_of_scores(postprocessor, stream) + +Apply `postprocessor` to a stream of free-energy score contributions and return +the transformed stream. + +The default fallback for `::Nothing` returns `stream` unchanged. Subtypes of +[`ReactiveMP.AbstractStreamPostprocessor`](@ref) may override this method. +""" +function postprocess_stream_of_scores end + +""" + postprocess_stream_of_outbound_messages(::Nothing, stream) = stream + +Pass-through fallback: when no stream postprocessor is configured, outbound +message streams are returned unchanged. +""" +postprocess_stream_of_outbound_messages(::Nothing, stream) = stream + +""" + postprocess_stream_of_marginals(::Nothing, stream) = stream + +Pass-through fallback: when no stream postprocessor is configured, marginal +streams are returned unchanged. +""" +postprocess_stream_of_marginals(::Nothing, stream) = stream + +""" + postprocess_stream_of_scores(::Nothing, stream) = stream + +Pass-through fallback: when no stream postprocessor is configured, score +streams are returned unchanged. +""" +postprocess_stream_of_scores(::Nothing, stream) = stream + +""" + CompositeStreamPostprocessor{T} <: AbstractStreamPostprocessor + +A [`ReactiveMP.AbstractStreamPostprocessor`](@ref) that applies a sequence of +inner postprocessors in order. The output of stage `i` is fed as the input of +stage `i + 1`, for each of the three stream kinds independently. + +# Fields +- `stages::T` — a tuple (or any iterable) of postprocessors to apply in order. + +# Example + +```julia +composite = CompositeStreamPostprocessor(( + ScheduleOnStreamPostprocessor(PendingScheduler()), + MyCustomPostprocessor(), +)) +``` + +See also: [`ReactiveMP.postprocess_stream_of_outbound_messages`](@ref), +[`ReactiveMP.postprocess_stream_of_marginals`](@ref), +[`ReactiveMP.postprocess_stream_of_scores`](@ref). +""" +struct CompositeStreamPostprocessor{T} <: AbstractStreamPostprocessor + stages::T +end + +function postprocess_stream_of_outbound_messages( + composite::CompositeStreamPostprocessor, stream +) + return reduce( + (stream, stage) -> + postprocess_stream_of_outbound_messages(stage, stream), + composite.stages; + init = stream, + ) +end + +function postprocess_stream_of_marginals( + composite::CompositeStreamPostprocessor, stream +) + return reduce( + (stream, stage) -> postprocess_stream_of_marginals(stage, stream), + composite.stages; + init = stream, + ) +end + +function postprocess_stream_of_scores( + composite::CompositeStreamPostprocessor, stream +) + return reduce( + (stream, stage) -> postprocess_stream_of_scores(stage, stream), + composite.stages; + init = stream, + ) +end diff --git a/src/postprocessors/scheduled.jl b/src/postprocessors/scheduled.jl new file mode 100644 index 000000000..4cb37fd61 --- /dev/null +++ b/src/postprocessors/scheduled.jl @@ -0,0 +1,48 @@ +import Rocket: release! + +""" + ScheduleOnStreamPostprocessor{S} <: AbstractStreamPostprocessor + +A [`ReactiveMP.AbstractStreamPostprocessor`](@ref) that redirects every emission +of the wrapped stream onto a Rocket.jl scheduler via the `schedule_on(scheduler)` +operator. This is the standard way to control *when* downstream subscribers +observe updates — for example, to batch a wave of inbound observations into a +single propagation step using a `PendingScheduler`, or to move work onto a +worker thread using an `AsyncScheduler`. + +The same scheduler is applied to all three stream kinds (outbound messages, +marginals, scores), which makes `ScheduleOnStreamPostprocessor` the direct +successor of the v5/early-v6 `ScheduleOnPipelineStage` + node-level scheduler +pair. + +# Fields +- `scheduler::S` — a Rocket.jl scheduler. Must be compatible with + `Rocket.schedule_on`. + +# Releasing scheduled updates + +If the wrapped scheduler buffers updates (e.g. `PendingScheduler`), call +`Rocket.release!` on the postprocessor to flush them. `release!` is also +defined for tuples and arrays of `ScheduleOnStreamPostprocessor`s for +convenience. + +See also: [`ReactiveMP.AbstractStreamPostprocessor`](@ref), +[`ReactiveMP.CompositeStreamPostprocessor`](@ref). +""" +struct ScheduleOnStreamPostprocessor{S} <: AbstractStreamPostprocessor + scheduler::S +end + +postprocess_stream_of_outbound_messages( + p::ScheduleOnStreamPostprocessor, stream +) = stream |> schedule_on(p.scheduler) + +postprocess_stream_of_marginals(p::ScheduleOnStreamPostprocessor, stream) = + stream |> schedule_on(p.scheduler) + +postprocess_stream_of_scores(p::ScheduleOnStreamPostprocessor, stream) = + stream |> schedule_on(p.scheduler) + +Rocket.release!(stage::ScheduleOnStreamPostprocessor) = release!(stage.scheduler) +Rocket.release!(stages::NTuple{N, <:ScheduleOnStreamPostprocessor}) where {N} = foreach(release!, stages) +Rocket.release!(stages::AbstractArray{<:ScheduleOnStreamPostprocessor}) = foreach(release!, stages) diff --git a/src/rule.jl b/src/rule.jl index 60296c6da..ac2b7625a 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -23,7 +23,7 @@ This function is used to compute an outbound message for a given node - `qnames`: Ordered marginal names in form of the Val type, eg. `::Val{ (:mean, :precision) }` - `marginals`: Tuple of marginals of the same length as `qnames` used to compute an outbound message - `meta`: Extra meta information -- `addons`: Extra addons information +- `annotations`: An `AnnotationDict` for writing annotations during rule execution - `__node`: Node reference For all available rules, see `ReactiveMP.print_rules_table()`. @@ -255,7 +255,7 @@ function call_rule_macro_parse_fn_args(inputs; specname, prefix, proxy) $(map(v -> apply_proxy(v, proxy), any.args[2:end])...), ))) end - return :($(proxy)($any, false, false, nothing)) + return :($(proxy)($any, false, false)) end names_arg = isempty(names) ? :nothing : :(Val{$(Expr(:tuple, map(n -> QuoteNode(Symbol(string(n)[(lprefix + 1):end])), names)...))}()) @@ -352,7 +352,7 @@ function rule_function_expression( metatype, whereargs, ) - addonsvar = gensym(:addons) + annotationsvar = gensym(:annotations) nodevar = gensym(:node) return quote function ReactiveMP.rule( @@ -364,13 +364,13 @@ function rule_function_expression( marginals_names::$(q_names), marginals::$(q_types), meta::$(metatype), - $(addonsvar), + $(annotationsvar), $(nodevar), ) where {$(whereargs...)} local getnode = () -> $nodevar local getnodefn = (args...) -> ReactiveMP.nodefunction($nodevar, args...) - local getaddons = () -> $addonsvar + local getannotations = () -> $annotationsvar $(body()) end end @@ -543,26 +543,10 @@ macro rule(fform, lambda) whereargs, ) do return quote - local _addons = getaddons() - # This trick allows us to use arbitrary control-flow logic - # inside rules, e.g. if-else-returns etc, however - # it makes it not-type-stable with respect to addons - # on my (bvdmitri) benchmarks it accounted for 2-3% slowdown - # when using addons, which is IMO acceptable, but can be changed - # in the future by banning return statements from the `@rule` macro - # I'm against of manually removing return statements as - # it is very hard to implement correctly, I would rather make it more stable - # when fast but error-prone - # Another way to speed-up this part a little bit would be to refactor addons - # in such a way that their structure is always known to the compiler and type stable - local _messagebody = () -> begin - $(on_index_init) - $(m_init_block...) - $(q_init_block...) - $(body) - end - local _message = _messagebody() - return _message, _addons + $(on_index_init) + $(m_init_block...) + $(q_init_block...) + $(body) end end ) @@ -572,22 +556,25 @@ macro rule(fform, lambda) end """ - @call_rule NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., addons = ... ]) + @call_rule NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., annotations = ... ]) -The `@call_rule` macro helps to call the `rule` method with an easier syntax. +The `@call_rule` macro helps to call the `rule` method with an easier syntax. The structure of the macro is almost the same as in the `@rule` macro, but there is no `begin ... end` block, but instead each argument must have a specified value with the `=` operator. The `@call_rule` accepts optional list of options before the functional form specification, for example: ```julia -@call_rule [ return_addons = true ] NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., addons = ... ]) +@call_rule [ fallback = MyFallback() ] NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., annotations = ... ]) ``` The list of available options is: -- `return_addons` - forces the `@call_rule` to return the tuple of `(result, addons)` - `fallback` - specifies the fallback rule to use in case the rule is not defined for the given `NodeType` and specified arguments +The optional `annotations` keyword accepts an `AnnotationDict` that will be passed to the rule body. +Any annotations written during rule execution (e.g. via `@logscale`) will be stored in this dict. +If not provided, a temporary `AnnotationDict()` is created internally. + See also: [`@rule`](@ref), [`rule`](@ref), [`@call_marginalrule`](@ref) """ macro call_rule(options, fform, args) @@ -605,10 +592,16 @@ function call_rule_expression(options, fform, args) @capture( args, - (inputs__, meta = meta_, addons = addons_) | - (inputs__, addons = addons_) | (inputs__, meta = meta_) | (inputs__,) + (inputs__, meta = meta_, annotations = annotations_) | + (inputs__, annotations = annotations_) | (inputs__, meta = meta_) | + (inputs__,) ) || error("Error in macro. Arguments specification is incorrect") + # Default annotations to a fresh AnnotationDict if not provided + if annotations === nothing + annotations = :(ReactiveMP.AnnotationDict()) + end + fuppertype = MacroHelpers.upper_type(fformtype) fbottomtype = MacroHelpers.bottom_type(fformtype) on_type, on_index, on_index_init = rule_macro_parse_on_tag(on) @@ -640,8 +633,6 @@ function call_rule_expression(options, fform, args) on_arg = call_rule_macro_construct_on_arg(on_type, on_index) # Options - # Option 1. Modifies the output of the `@call_rule` macro and returns a tuple of the result and the enabled addons - return_addons = false fallback = nothing if !isnothing(options) @@ -652,9 +643,7 @@ function call_rule_expression(options, fform, args) @capture(option, key_ = value_) || error( "Error in macro. An options should be in a form of `option = value`, got $(option).", ) - if key === :return_addons - return_addons = Bool(value) - elseif key === :fallback + if key === :fallback fallback = value else @warn "Unknown option in the `@call_rule` macro: $(option)" @@ -663,10 +652,8 @@ function call_rule_expression(options, fform, args) end __rule_result_sym = gensym(:call_rule_result) - __distribution_sym = gensym(:call_rule_distribution) - __addons_sym = gensym(:call_rule_addons) - call = quote + output = quote local $__rule_result_sym = ReactiveMP.rule( $fbottomtype, $on_arg, @@ -676,7 +663,7 @@ function call_rule_expression(options, fform, args) $q_names_arg, $q_values_arg, $meta, - $addons, + $annotations, $node, ) if ($__rule_result_sym) isa ReactiveMP.RuleMethodError && @@ -690,19 +677,13 @@ function call_rule_expression(options, fform, args) $q_names_arg, $q_values_arg, $meta, - $addons, + $annotations, $node, ) elseif ($__rule_result_sym) isa ReactiveMP.RuleMethodError throw($__rule_result_sym) end - local $(__distribution_sym), $(__addons_sym) = $(__rule_result_sym) - end - - output = if !return_addons - :($call; $__distribution_sym) - else - :($call) + $__rule_result_sym end return esc(output) @@ -1558,11 +1539,11 @@ struct RuleMethodError qnames marginals meta - addons + annotations node end -rule(fform, on, vconstraint, mnames, messages, qnames, marginals, meta, addons, __node) = RuleMethodError( +rule(fform, on, vconstraint, mnames, messages, qnames, marginals, meta, annotations, __node) = RuleMethodError( fform, on, vconstraint, @@ -1571,7 +1552,7 @@ rule(fform, on, vconstraint, mnames, messages, qnames, marginals, meta, addons, qnames, marginals, meta, - addons, + annotations, __node, ) @@ -1629,9 +1610,6 @@ function Base.showerror(io::IO, error::RuleMethodError) println(io, "\n\nPossible fix, define:\n") println(io, possible_fix_definition) - if !isnothing(error.addons) - println(io, "\n\nEnabled addons: ", error.addons, "\n") - end node_rules = filter( m -> ReactiveMP.get_node_from_rule_method(m) == spec_fform, @@ -1704,7 +1682,7 @@ function Base.showerror(io::IO, error::RuleMethodError) println(io, "rule.qnames: ", error.qnames) println(io, "rule.marginals: ", error.marginals) println(io, "rule.meta: ", error.meta) - println(io, "rule.addons: ", error.addons) + println(io, "rule.annotations: ", error.annotations) end end diff --git a/src/rules/discrete_transition/a.jl b/src/rules/discrete_transition/a.jl index 2a6b59ebd..cd643fd53 100644 --- a/src/rules/discrete_transition/a.jl +++ b/src/rules/discrete_transition/a.jl @@ -23,13 +23,12 @@ function ReactiveMP.rule( }, }, meta::Any, - addons::Any, + annotations::Any, ::Any, ) where {M, m_names} # Special case, if there is only one marginal, we can return the result directly. if M === 1 - return DirichletCollection(components(getdata(first(marginals))) .+ 1), - addons + return DirichletCollection(components(getdata(first(marginals))) .+ 1) end # First, we have to count the number of dimensions that we need for the contingency matrix. c = 0 @@ -48,5 +47,5 @@ function ReactiveMP.rule( result = result .* v end result = Contingency(result) - return DirichletCollection(components(result) .+ 1), addons + return DirichletCollection(components(result) .+ 1) end diff --git a/src/rules/discrete_transition/categoricals.jl b/src/rules/discrete_transition/categoricals.jl index 57f19676b..3b7774db9 100644 --- a/src/rules/discrete_transition/categoricals.jl +++ b/src/rules/discrete_transition/categoricals.jl @@ -230,14 +230,13 @@ function ReactiveMP.rule( }, }, meta::Any, - addons::Any, + annotations::Any, ::Any, ) where {S, M, N, mes_names, mar_names} q_a = marginals[findfirst(==(:a), mar_names)] return discrete_transition_structured_message_rule( mes_names, messages, mar_names, marginals, q_a - ), - addons + ) end function ReactiveMP.rule( @@ -258,12 +257,11 @@ function ReactiveMP.rule( }, }, meta::Any, - addons::Any, + annotations::Any, ::Any, ) where {S, M, mar_names} q_a = marginals[findfirst(==(:a), mar_names)] return discrete_transition_structured_message_rule( (), (), mar_names, marginals, q_a - ), - addons + ) end diff --git a/src/rules/fallbacks.jl b/src/rules/fallbacks.jl index 5fbb15437..6e21fa45e 100644 --- a/src/rules/fallbacks.jl +++ b/src/rules/fallbacks.jl @@ -80,7 +80,7 @@ function rulefallback_nodefunction( qnames, marginals, meta, - addons, + annotations, __node, ) return rulefallback_nodefunction( @@ -94,7 +94,7 @@ function rulefallback_nodefunction( qnames, marginals, meta, - addons, + annotations, __node, ) end @@ -110,14 +110,14 @@ function rulefallback_nodefunction( qnames, marginals, meta, - addons, + annotations, __node, ) vals = _mergevals(mnames, qnames) means = _extractvalues(fallback.extractfn, messages, marginals) kwargs = NamedTuple{vals}(means) fn = ReactiveMP.nodefunction(fform, on; kwargs...) - return FallbackNodeFunctionUnnormalizedLogPdf(fn), addons + return FallbackNodeFunctionUnnormalizedLogPdf(fn) end function rulefallback_nodefunction( @@ -131,7 +131,7 @@ function rulefallback_nodefunction( qnames, marginals, meta, - addons, + annotations, __node, ) error( @@ -150,7 +150,7 @@ function rulefallback_nodefunction( qnames, marginals, meta, - addons, + annotations, __node, ) error( diff --git a/src/rules/gamma_mixture/switch.jl b/src/rules/gamma_mixture/switch.jl index 12444684c..3ec5893e5 100644 --- a/src/rules/gamma_mixture/switch.jl +++ b/src/rules/gamma_mixture/switch.jl @@ -6,7 +6,7 @@ AverageEnergy(), GammaShapeRate, Val{(:out, :α, :β)}(), - map((q) -> Marginal(q, false, false, nothing), (q_out, a, b)), + map((q) -> Marginal(q, false, false), (q_out, a, b)), nothing, ) end diff --git a/src/rules/mixture/inputs.jl b/src/rules/mixture/inputs.jl index 73ac781d9..cd51f7c48 100644 --- a/src/rules/mixture/inputs.jl +++ b/src/rules/mixture/inputs.jl @@ -1,7 +1,7 @@ @rule Mixture((:inputs, k), Marginalisation) (m_out::Any, m_switch::Any) = begin # `messages` are available from the `@rule` macro itself - @logscale getlogscale(messages[1]) + - getlogscale(messages[2]) + + @logscale getlogscale(ReactiveMP.getannotations(messages[1])) + + getlogscale(ReactiveMP.getannotations(messages[2])) + log(probvec(messages[2])[k]) return m_out end diff --git a/src/rules/mixture/out.jl b/src/rules/mixture/out.jl index 11931fbd8..844ed83f5 100644 --- a/src/rules/mixture/out.jl +++ b/src/rules/mixture/out.jl @@ -3,10 +3,14 @@ ) where {N} = begin # get logscales of different inputs - logscales_inputs = map(getlogscale, messages[2]) + logscales_inputs = map( + m -> getlogscale(ReactiveMP.getannotations(m)), messages[2] + ) # get logscales of Categorical/Bernoulli - logscales_switch = getlogscale(messages[1]) .+ log.(probvec(m_switch)) + logscales_switch = + getlogscale(ReactiveMP.getannotations(messages[1])) .+ + log.(probvec(m_switch)) # compute logscales of individual components logscales = logscales_inputs .+ logscales_switch diff --git a/src/rules/mixture/switch.jl b/src/rules/mixture/switch.jl index 789e1dcf7..a373eb620 100644 --- a/src/rules/mixture/switch.jl +++ b/src/rules/mixture/switch.jl @@ -6,7 +6,16 @@ # `messages` are available from the `@rule` macro itself logscales = map( input -> getlogscale( - multiply_messages(GenericProd(), messages[1], input) + ReactiveMP.getannotations( + compute_product_of_two_messages( + ReactiveMP.randomvar(; label = :mixture_switch_rule), + ReactiveMP.MessageProductContext(; + annotations = (LogScaleAnnotations(),) + ), + messages[1], + input, + ), + ), ), messages[2], ) diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index 8e6759290..590dfadb7 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -159,7 +159,10 @@ end meta::Union{<:AbstractCorrectionStrategy, Nothing}, ) = begin return @call_rule typeof(*)(:in, Marginalisation) ( - m_A = m_out, m_out = m_A, meta = meta, addons = getaddons() + m_A = m_out, + m_out = m_A, + meta = meta, + annotations = getannotations(), ) # symmetric rule end @@ -175,6 +178,6 @@ end m_A = PointMass(mean(m_A).λ), m_out = m_out, meta = meta, - addons = getaddons(), + annotations = getannotations(), ) # dispatch to real * normal end diff --git a/src/rules/multiplication/out.jl b/src/rules/multiplication/out.jl index ac6d67bac..2dcd56bb6 100644 --- a/src/rules/multiplication/out.jl +++ b/src/rules/multiplication/out.jl @@ -18,7 +18,7 @@ end meta::Union{<:AbstractCorrectionStrategy, Nothing}, ) = begin return @call_rule typeof(*)(:out, Marginalisation) ( - m_A = m_in, m_in = m_A, meta = meta, addons = getaddons() + m_A = m_in, m_in = m_A, meta = meta, annotations = getannotations() ) # symmetric rule end @@ -41,7 +41,7 @@ end meta::Union{<:AbstractCorrectionStrategy, Nothing}, ) where {F <: NormalDistributionsFamily} = begin return @call_rule typeof(*)(:out, Marginalisation) ( - m_A = m_in, m_in = m_A, meta = meta, addons = getaddons() + m_A = m_in, m_in = m_A, meta = meta, annotations = getannotations() ) # symmetric rule end @@ -81,7 +81,7 @@ end meta::Union{<:AbstractCorrectionStrategy, Nothing}, ) = begin return @call_rule typeof(*)(:out, Marginalisation) ( - m_A = m_in, m_in = m_A, meta = meta, addons = getaddons() + m_A = m_in, m_in = m_A, meta = meta, annotations = getannotations() ) # symmetric rule end @@ -138,7 +138,7 @@ end meta::Union{<:AbstractCorrectionStrategy, Nothing}, ) = begin return @call_rule typeof(*)(:out, Marginalisation) ( - m_A = m_in, m_in = m_A, meta = meta, addons = getaddons() + m_A = m_in, m_in = m_A, meta = meta, annotations = getannotations() ) # symmetric rule end @@ -154,7 +154,7 @@ end m_A = PointMass(mean(m_A).λ), m_in = m_in, meta = meta, - addons = getaddons(), + annotations = getannotations(), ) # dispatch to real * normal end diff --git a/src/rules/normal_mixture/switch.jl b/src/rules/normal_mixture/switch.jl index d03e8c7b1..a16a29175 100644 --- a/src/rules/normal_mixture/switch.jl +++ b/src/rules/normal_mixture/switch.jl @@ -22,7 +22,7 @@ function rule_nm_switch_k(::Type{Univariate}, q_out, m, p) AverageEnergy(), NormalMeanPrecision, Val{(:out, :μ, :τ)}(), - map((q) -> Marginal(q, false, false, nothing), (q_out, m, p)), + map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing, ) end @@ -32,7 +32,7 @@ function rule_nm_switch_k(::Type{Multivariate}, q_out, m, p) AverageEnergy(), MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}(), - map((q) -> Marginal(q, false, false, nothing), (q_out, m, p)), + map((q) -> Marginal(q, false, false), (q_out, m, p)), nothing, ) end diff --git a/src/score/node.jl b/src/score/node.jl index b5426202f..f9c001cc3 100644 --- a/src/score/node.jl +++ b/src/score/node.jl @@ -9,8 +9,7 @@ function score( ::FactorBoundFreeEnergy, node::AbstractFactorNode, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} return score( T, @@ -18,8 +17,7 @@ function score( sdtype(node), node, collect_meta(functionalform(node), meta), - skip_strategy, - scheduler, + stream_postprocessors, ) end @@ -31,14 +29,11 @@ function score( ::Deterministic, node::AbstractFactorNode, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} - fnstream = let skip_strategy = skip_strategy, scheduler = scheduler + fnstream = (interface) -> - apply_skip_filter(messagein(interface), skip_strategy) |> - schedule_on(scheduler) - end + get_stream_of_inbound_messages(interface) |> skip_initial() tinterfaces = Tuple(getinterfaces(node)) stream = combineLatest(map(fnstream, tinterfaces), PushNew()) @@ -67,13 +62,17 @@ function score( ), false, false, - nothing, ) return convert(T, -score(DifferentialEntropy(), marginal)) end end - return stream |> map(T, mapping) + stream_of_scores = stream |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + + return stream_of_scores end ## Stochastic mapping @@ -84,16 +83,13 @@ function score( ::Stochastic, node::AbstractFactorNode, meta, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} - fnstream = let skip_strategy = skip_strategy, scheduler = scheduler + fnstream = (localmarginal) -> - apply_skip_filter(getmarginal(localmarginal), skip_strategy) |> - schedule_on(scheduler) - end + get_stream_of_marginals(localmarginal) |> skip_initial() - localmarginals = getmarginals(getlocalclusters(node)) + localmarginals = get_node_local_marginals(getlocalclusters(node)) stream = combineLatest(map(fnstream, localmarginals), PushNew()) mapping = @@ -107,5 +103,10 @@ function score( end end - return stream |> map(T, mapping) + stream_of_scores = stream |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + + return stream_of_scores end diff --git a/src/score/score.jl b/src/score/score.jl index 9de3563b1..2a61177c1 100644 --- a/src/score/score.jl +++ b/src/score/score.jl @@ -42,8 +42,7 @@ function score( let is_joint_clamped = is_clamped(joint), is_joint_initial = is_initial(joint) - (data) -> - Marginal(data, is_joint_clamped, is_joint_initial, nothing) + (data) -> Marginal(data, is_joint_clamped, is_joint_initial) end mod_marginals = TupleTools.insertat( @@ -89,9 +88,7 @@ function score(::DifferentialEntropy, marginal::Marginal{<:NamedTuple}) (data) -> score( DifferentialEntropy(), - Marginal( - data, is_marginal_clamped, is_marginal_initial, nothing - ), + Marginal(data, is_marginal_clamped, is_marginal_initial), ) end diff --git a/src/score/variable.jl b/src/score/variable.jl index 65af8b109..14dec72cf 100644 --- a/src/score/variable.jl +++ b/src/score/variable.jl @@ -6,20 +6,22 @@ function score( ::Type{T}, ::VariableBoundEntropy, variable::RandomVariable, - skip_strategy, - scheduler, + stream_postprocessors, ) where {T <: CountingReal} mapping = let d = degree(variable) (marginal) -> begin # The entropy of point masses is not finite - # In this case we treat them as clamped variables, such that we should multiply + # In this case we treat them as clamped variables, such that we should multiply # their influence on `d` instead of `d - 1` scaling = !ispointmass(marginal) ? (d - 1) : d entropy = convert(T, score(DifferentialEntropy(), marginal)) return scaling * entropy end end - return getmarginal(variable, skip_strategy) |> - schedule_on(scheduler) |> - map(T, mapping) + stream_of_scores = + get_stream_of_marginals(variable) |> skip_initial() |> map(T, mapping) + stream_of_scores = postprocess_stream_of_scores( + stream_postprocessors, stream_of_scores + ) + return stream_of_scores end diff --git a/src/variable.jl b/src/variable.jl new file mode 100644 index 000000000..0d7f89ab8 --- /dev/null +++ b/src/variable.jl @@ -0,0 +1,152 @@ +""" + AbstractVariable + +An abstract supertype for all variable types in the factor graph. +Concrete subtypes include: +- [`ReactiveMP.RandomVariable`](@ref) +- [`ReactiveMP.ConstVariable`](@ref) +- [`ReactiveMP.DataVariable`](@ref). +""" +abstract type AbstractVariable end + +Base.broadcastable(v::AbstractVariable) = Ref(v) + +# Helper functions + +""" + ReactiveMP.degree(variable) + +Returns the number of factor nodes connected to `variable`, i.e. the number of message streams. +""" +function degree end + +""" + ReactiveMP.israndom(variable) + ReactiveMP.israndom(variables::AbstractArray) + +Returns `true` if `variable` is a [`ReactiveMP.RandomVariable`](@ref). +For an array, returns `true` only if all elements are random variables. +""" +function israndom end + +""" + ReactiveMP.isdata(variable) + ReactiveMP.isdata(variables::AbstractArray) + +Returns `true` if `variable` is a [`DataVariable`](@ref). +For an array, returns `true` only if all elements are data variables. +""" +function isdata end + +""" + ReactiveMP.isconst(variable) + ReactiveMP.isconst(variables::AbstractArray) + +Returns `true` if `variable` is a [`ReactiveMP.ConstVariable`](@ref). +For an array, returns `true` only if all elements are const variables. +""" +function isconst end + +israndom(v::AbstractArray{<:AbstractVariable}) = all(israndom, v) +isdata(v::AbstractArray{<:AbstractVariable}) = all(isdata, v) +isconst(v::AbstractArray{<:AbstractVariable}) = all(isconst, v) + +""" + ReactiveMP.create_new_stream_of_inbound_messages!(variable) + +Allocates a new per-connection [`ReactiveMP.MessageObservable`](@ref) for `variable` and registers it as an additional inbound message slot. +Returns a tuple `(observable, index)` where `observable` is the newly created stream and `index` is its position in the variable's internal `input_messages` collection. + +Called once per factor node connection at graph construction time. The returned `observable` is stored as the *outbound* message stream of the corresponding [`ReactiveMP.NodeInterface`](@ref) — it is the outbound message from the node's perspective and the inbound message from the variable's perspective. All streams are unconnected (lazy) until [`ReactiveMP.activate!`](@ref) is called. + +For [`ReactiveMP.ConstVariable`](@ref) the same shared observable is returned for every connection; no per-connection slot is allocated. + +See also: [`ReactiveMP.MessageObservable`](@ref), [`ReactiveMP.NodeInterface`](@ref) +""" +function create_new_stream_of_inbound_messages! end + +""" + ReactiveMP.get_stream_of_predictions(variable) + +Returns the prediction observable stream for `variable`. +For [`DataVariable`](@ref), the prediction is the product of all inbound messages. +See also [`ReactiveMP.set_stream_of_predictions!`](@ref). +""" +function get_stream_of_predictions end + +""" + ReactiveMP.set_stream_of_predictions!(variable, stream) + +Connects `stream` as the prediction observable for `variable`. +See also [`ReactiveMP.get_stream_of_predictions`](@ref). +""" +function set_stream_of_predictions! end + +""" + ReactiveMP.get_stream_of_marginals(variable) + +Returns the marginal observable stream for `variable`. +See also [`ReactiveMP.set_stream_of_marginals!`](@ref), [`ReactiveMP.set_initial_marginal!`](@ref). +""" +function get_stream_of_marginals end + +""" + ReactiveMP.set_stream_of_marginals!(variable, stream) + +Connects `stream` as the marginal observable for `variable`. +See also [`ReactiveMP.get_stream_of_marginals`](@ref). +""" +function set_stream_of_marginals! end + +""" + ReactiveMP.set_initial_marginal!(variable, marginal) + ReactiveMP.set_initial_marginal!(variables::AbstractArray, marginals) + +Sets the initial marginal belief for `variable` by pushing `marginal` as an initial (non-clamped) value +into [`ReactiveMP.get_stream_of_marginals`](@ref). For arrays, applies element-wise. +See also [`ReactiveMP.set_initial_message!`](@ref). +""" +function set_initial_marginal!(variable::AbstractVariable, marginal) + set_initial_marginal!(get_stream_of_marginals(variable), marginal) +end + +set_initial_marginal!(variables::AbstractArray{<:AbstractVariable}, marginal::PointMass) = _set_initial_marginal!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) +set_initial_marginal!(variables::AbstractArray{<:AbstractVariable}, marginal::Distribution) = _set_initial_marginal!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) +set_initial_marginal!(variables::AbstractArray{<:AbstractVariable}, marginals) = _set_initial_marginal!(Base.IteratorSize(marginals), variables, marginals) + +function _set_initial_marginal!( + ::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, marginals +) + @assert length(variables) == length(marginals) "Variables $(variables) and marginals $(marginals) should have the same length" + foreach(zip(variables, marginals)) do (variable, marginal) + set_initial_marginal!(variable, marginal) + end +end + +""" + ReactiveMP.set_initial_message!(variable, message) + ReactiveMP.set_initial_message!(variables::AbstractArray, messages) + +Sets the initial message for all interfaces of `variable` by pushing `message` into each outbound message stream. +For arrays, applies element-wise. See also [`ReactiveMP.set_initial_marginal!`](@ref). +""" +function set_initial_message!(variable::AbstractVariable, message) + for i in 1:degree(variable) + set_initial_message!( + get_stream_of_outbound_messages(variable, i), message + ) + end +end + +set_initial_message!(variables::AbstractArray{<:AbstractVariable}, message::PointMass) = _set_initial_message!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) +set_initial_message!(variables::AbstractArray{<:AbstractVariable}, message::Distribution) = _set_initial_message!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) +set_initial_message!(variables::AbstractArray{<:AbstractVariable}, messages) = _set_initial_message!(Base.IteratorSize(messages), variables, messages) + +function _set_initial_message!( + ::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, messages +) + @assert length(variables) == length(messages) "Variables $(variables) and messages $(messages) should have the same length" + foreach(zip(variables, messages)) do (variable, message) + set_initial_message!(variable, message) + end +end diff --git a/src/variables/constant.jl b/src/variables/constant.jl index d52528092..c02cf4f04 100644 --- a/src/variables/constant.jl +++ b/src/variables/constant.jl @@ -1,21 +1,36 @@ export constvar, ConstVariable +""" + ConstVariable <: AbstractVariable + +Represents a constant (clamped) variable in the factor graph. The value is fixed at creation time and +wrapped in a `PointMass` distribution. Messages and marginals from this variable are always marked as clamped. +Use [`constvar`](@ref) to create an instance. + +See also: [`ReactiveMP.RandomVariable`](@ref), [`ReactiveMP.DataVariable`](@ref) +""" mutable struct ConstVariable <: AbstractVariable marginal :: MarginalObservable messageout :: MessageObservable constant :: Any nconnected :: Int + label :: Any end -function ConstVariable(constant) +function ConstVariable(constant; label = nothing) marginal = MarginalObservable() - connect!(marginal, of(Marginal(PointMass(constant), true, false, nothing))) + connect!(marginal, of(Marginal(PointMass(constant), true, false))) messageout = MessageObservable(AbstractMessage) - connect!(messageout, of(Message(PointMass(constant), true, false, nothing))) - return ConstVariable(marginal, messageout, constant, 0) + connect!(messageout, of(Message(PointMass(constant), true, false))) + return ConstVariable(marginal, messageout, constant, 0, label) end -constvar(constant) = ConstVariable(constant) +""" + constvar(constant; label = nothing) + +Creates a new [`ReactiveMP.ConstVariable`](@ref) with the given `constant` value and an optional `label` for identification. +""" +constvar(constant; label = nothing) = ConstVariable(constant; label = label) degree(constvar::ConstVariable) = constvar.nconnected getconst(constvar::ConstVariable) = constvar.constant @@ -27,19 +42,25 @@ isdata(::AbstractArray{<:ConstVariable}) = false isconst(::ConstVariable) = true isconst(::AbstractArray{<:ConstVariable}) = true -function create_messagein!(constvar::ConstVariable) +get_stream_of_marginals(constvar::ConstVariable) = constvar.marginal +get_stream_of_predictions(constvar::ConstVariable) = constvar.marginal + +set_stream_of_marginals!(constvar::ConstVariable, stream) = error( + "It is not possible to set a stream of marginals for a `ConstVariable`" +) +set_stream_of_predictions!(constvar::ConstVariable, stream) = error( + "It is not possible to set a stream of predictions for a `ConstVariable`" +) + +function create_new_stream_of_inbound_messages!(constvar::ConstVariable) constvar.nconnected += 1 return constvar.messageout, 1 end -function messagein(::ConstVariable, ::Int) +function get_stream_of_inbound_messages(::ConstVariable, ::Int) error("ConstVariable does not save inbound messages.") end -function messageout(constvar::ConstVariable, ::Int) +function get_stream_of_outbound_messages(constvar::ConstVariable, ::Int) return constvar.messageout end - -_getmarginal(constvar::ConstVariable) = constvar.marginal -_setmarginal!(::ConstVariable, observable) = error("It is not possible to set a marginal stream for `ConstVariable`") -_makemarginal(::ConstVariable) = error("It is not possible to make marginal stream for `ConstVariable`") diff --git a/src/variables/data.jl b/src/variables/data.jl index fb3cb4656..82c32ce3a 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,13 +1,22 @@ -export datavar, DataVariable, update!, DataVariableActivationOptions +export datavar, DataVariable, new_observation!, DataVariableActivationOptions +""" + DataVariable <: AbstractVariable + +Represents an observed variable in the factor graph. Unlike [`ReactiveMP.ConstVariable`](@ref), the data is not fixed +at creation time and can be updated later via [`ReactiveMP.new_observation!`](@ref). Use [`datavar`](@ref) to create an instance. + +See also: [`ReactiveMP.RandomVariable`](@ref), [`ReactiveMP.ConstVariable`](@ref) +""" mutable struct DataVariable{M, P} <: AbstractVariable input_messages :: Vector{MessageObservable{AbstractMessage}} marginal :: MarginalObservable messageout :: M prediction :: P + label :: Any end -function DataVariable() +function DataVariable(; label = nothing) messageout = RecentSubject(Message) marginal = MarginalObservable() prediction = MarginalObservable() @@ -16,10 +25,16 @@ function DataVariable() marginal, messageout, prediction, + label, ) end -datavar() = DataVariable() +""" + datavar(; label = nothing) + +Creates a new [`ReactiveMP.DataVariable`](@ref) with an optional `label` for identification. +""" +datavar(; label = nothing) = DataVariable(; label = label) degree(datavar::DataVariable) = length(datavar.input_messages) @@ -30,20 +45,41 @@ isdata(::AbstractArray{<:DataVariable}) = true isconst(::DataVariable) = false isconst(::AbstractArray{<:DataVariable}) = false -function create_messagein!(datavar::DataVariable) - messagein = MessageObservable(AbstractMessage) - push!(datavar.input_messages, messagein) - return messagein, length(datavar.input_messages) +get_stream_of_marginals(datavar::DataVariable) = datavar.marginal +get_stream_of_predictions(datavar::DataVariable) = datavar.prediction + +set_stream_of_marginals!(datavar::DataVariable, stream) = connect!( + datavar.marginal, stream +) +set_stream_of_predictions!(datavar::DataVariable, stream) = connect!( + datavar.prediction, stream +) + +function create_new_stream_of_inbound_messages!(datavar::DataVariable) + new_stream_of_inbound_messages = MessageObservable(AbstractMessage) + push!(datavar.input_messages, new_stream_of_inbound_messages) + return new_stream_of_inbound_messages, length(datavar.input_messages) end -function messagein(datavar::DataVariable, index::Int) +function get_stream_of_inbound_messages(datavar::DataVariable, index::Int) return datavar.input_messages[index] end -function messageout(datavar::DataVariable, ::Int) +function get_stream_of_outbound_messages(datavar::DataVariable, ::Int) return datavar.messageout end +""" + DataVariableActivationOptions + +Collects all configuration needed to activate a [`ReactiveMP.DataVariable`](@ref). Passed to [`ReactiveMP.activate!(::DataVariable, ::DataVariableActivationOptions)`](@ref). + +Fields: +- `prediction::Bool` — if `true`, a prediction stream is built during activation as the product of all inbound (backward) messages +- `linked::Bool` — if `true`, the variable's observation stream is driven by a deterministic transformation of other variables' marginals rather than by direct [`ReactiveMP.new_observation!`](@ref) calls +- `transform` — the transformation function applied to the linked variables' marginals (used only when `linked = true`) +- `args` — the collection of linked variables or constants whose marginals are combined (used only when `linked = true`) +""" struct DataVariableActivationOptions prediction::Bool linked::Bool @@ -55,11 +91,39 @@ DataVariableActivationOptions() = DataVariableActivationOptions( false, false, nothing, nothing ) +""" + ReactiveMP.activate!(datavar::DataVariable, options::DataVariableActivationOptions) + +Wires all reactive streams of a [`ReactiveMP.DataVariable`](@ref) into the factor graph. + +Activation proceeds in up to three steps: + +1. **Prediction** — if `options.prediction` is `true`, a prediction stream is built via `collectLatest` over all inbound (backward) [`ReactiveMP.MessageObservable`](@ref)s: once all backward messages have emitted and again when all of them update, their product is emitted as the model's prior expectation for this variable. + +2. **Linked variables** — if `options.linked` is `true`, a subscription is created over a transformed combination of other variables' marginals. Each update is forwarded automatically to [`ReactiveMP.new_observation!`](@ref), making the data variable's observation a deterministic function of those variables. + +3. **Marginal** — always wired: the marginal stream is `messageout |> map(as_marginal)`, so the marginal always equals the most recently pushed observation. + +See also: [`ReactiveMP.DataVariableActivationOptions`](@ref), [`ReactiveMP.activate!(::RandomVariable, ::RandomVariableActivationOptions)`](@ref) +""" function activate!( datavar::DataVariable, options::DataVariableActivationOptions ) if options.prediction - _setprediction!(datavar, _makeprediction(datavar)) + # if the prediction is requested, we instantiate the stream of predictions + # as the product of all inbound messages to the datavar + # otherwise the stream of predictions is empty + stream_of_predictions = collectLatest( + AbstractMessage, + Marginal, + datavar.input_messages, + (messages) -> as_marginal( + compute_product_of_messages( + datavar, MessageProductContext(), messages + ), + ), + ) + set_stream_of_predictions!(datavar, stream_of_predictions) end if options.linked @@ -73,56 +137,55 @@ function activate!( return __apply_link(f, getrecent.(args)) end) # This subscription should unsubscribe automatically when the linked `datavar`s complete - subscribe!(linkstream, (val) -> update!(datavar, val)) + subscribe!(linkstream, (val) -> new_observation!(datavar, val)) end # The marginal stream is always the same as the message out - connect!(datavar.marginal, datavar.messageout |> map(Marginal, as_marginal)) + # but converted to Marginal with the as_marginal function + stream_of_marginals = datavar.messageout |> map(Marginal, as_marginal) + set_stream_of_marginals!(datavar, stream_of_marginals) return nothing end -__link_getmarginal(constant) = of( - Marginal(PointMass(constant), true, false, nothing) -) -__link_getmarginal(l::AbstractVariable) = getmarginal(l, IncludeAll()) -__link_getmarginal(l::AbstractArray{<:AbstractVariable}) = getmarginals( - l, IncludeAll() +__link_getmarginal(constant) = of(Marginal(PointMass(constant), true, false)) +__link_getmarginal(l::AbstractVariable) = get_stream_of_marginals(l) +__link_getmarginal(l::AbstractArray{<:AbstractVariable}) = collectLatest( + map(get_stream_of_marginals, l) ) __apply_link(f::F, args) where {F} = __apply_link(f, getdata.(args)) __apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = f(mean.(args)...) -_getmarginal(datavar::DataVariable) = datavar.marginal -_setmarginal!(::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`") -_makemarginal(::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`") +""" + new_observation!(datavar::DataVariable, data) + new_observation!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray) -update!(datavar::DataVariable, data) = update!(datavar, PointMass(data)) -update!(datavar::DataVariable, data::PointMass) = next!(datavar.messageout, Message(data, false, false, nothing)) -update!(datavar::DataVariable, ::Missing) = next!(datavar.messageout, Message(missing, false, false, nothing)) +Provides a new observation to a [`ReactiveMP.DataVariable`](@ref) (or an array of data variables). +The `data` is wrapped in a `PointMass` distribution and pushed as a new message. +Pass `missing` to indicate that the observation is not available. +""" +new_observation!(datavar::DataVariable, data) = new_observation!( + datavar, PointMass(data) +) +new_observation!(datavar::DataVariable, data::PointMass) = next!(datavar.messageout, Message(data, false, false)) +new_observation!(datavar::DataVariable, ::Missing) = next!(datavar.messageout, Message(missing, false, false)) -function update!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray) +function new_observation!( + datavars::AbstractArray{<:DataVariable}, data::AbstractArray +) @assert size(datavars) === size(data) """ - Invalid `update!` call: size of datavar array and data must match: `variables` has size $(size(datavars)) and `data` has size $(size(data)). + Invalid `new_observation!` call: size of datavar array and data must match: `variables` has size $(size(datavars)) and `data` has size $(size(data)). """ foreach(zip(datavars, data)) do (var, d) - update!(var, d) + new_observation!(var, d) end end -function update!(datavars::AbstractArray{<:DataVariable}, data::Missing) +function new_observation!( + datavars::AbstractArray{<:DataVariable}, data::Missing +) foreach(datavars) do var - update!(var, data) + new_observation!(var, data) end end - -marginal_prod_fn(datavar::DataVariable) = marginal_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) - -_getprediction(datavar::DataVariable) = datavar.prediction -_setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) -_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar)) diff --git a/src/variables/random.jl b/src/variables/random.jl index 827260b6a..fda0718ec 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -2,20 +2,41 @@ export randomvar, RandomVariable, RandomVariableActivationOptions ## Random variable implementation +""" + RandomVariable <: AbstractVariable + +Represents a latent (unobserved) variable in the factor graph. Random variables collect incoming and outgoing messages +from connected factor nodes and maintain a marginal belief. Use [`randomvar`](@ref) to create an instance. + +See also: [`ReactiveMP.ConstVariable`](@ref), [`ReactiveMP.DataVariable`](@ref) +""" mutable struct RandomVariable <: AbstractVariable - input_messages::Vector{MessageObservable{AbstractMessage}} - output_messages::Vector{MessageObservable{Message}} - marginal::MarginalObservable + input_messages :: Vector{MessageObservable{AbstractMessage}} + output_messages :: Vector{MessageObservable{Message}} + marginal :: MarginalObservable + label :: Any end -function randomvar() +""" + randomvar(; label = nothing) + +Creates a new [`ReactiveMP.RandomVariable`](@ref) with an optional `label` for identification. +""" +function randomvar(; label = nothing) return RandomVariable( Vector{MessageObservable{AbstractMessage}}(), Vector{MessageObservable{Message}}(), MarginalObservable(), + label, ) end +""" + ReactiveMP.degree(randomvar::RandomVariable) + +Returns the number of factor nodes connected to `randomvar`, equal to the length of its inbound message streams collection. +See also [`ReactiveMP.degree`](@ref). +""" degree(randomvar::RandomVariable) = length(randomvar.input_messages) israndom(::RandomVariable) = true @@ -25,43 +46,65 @@ isdata(::AbstractArray{<:RandomVariable}) = false isconst(::RandomVariable) = false isconst(::AbstractArray{<:RandomVariable}) = false -function create_messagein!(randomvar::RandomVariable) - messagein = MessageObservable(AbstractMessage) - push!(randomvar.input_messages, messagein) - return messagein, length(randomvar.input_messages) +get_stream_of_marginals(randomvar::RandomVariable) = randomvar.marginal +get_stream_of_predictions(randomvar::RandomVariable) = randomvar.marginal + +set_stream_of_marginals!(randomvar::RandomVariable, stream) = connect!( + randomvar.marginal, stream +) +set_stream_of_predictions!(randomvar::RandomVariable, stream) = error( + "It is not possible to set a stream of predictions for `RandomVariable`" +) + +function create_new_stream_of_inbound_messages!(randomvar::RandomVariable) + new_stream_of_inbound_messages = MessageObservable(AbstractMessage) + push!(randomvar.input_messages, new_stream_of_inbound_messages) + return new_stream_of_inbound_messages, length(randomvar.input_messages) end -function messagein(randomvar::RandomVariable, index::Int) +function get_stream_of_inbound_messages(randomvar::RandomVariable, index::Int) return randomvar.input_messages[index] end -function messageout(randomvar::RandomVariable, index::Int) +function get_stream_of_outbound_messages(randomvar::RandomVariable, index::Int) return randomvar.output_messages[index] end -const DefaultMessageProdFn = messages_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) -const DefaultMarginalProdFn = marginal_prod_fn( - FoldLeftProdStrategy(), - GenericProd(), - UnspecifiedFormConstraint(), - FormConstraintCheckLast(), -) - -struct RandomVariableActivationOptions{S, F, M} - scheduler::S - message_prod_fn::F - marginal_prod_fn::M +""" + RandomVariableActivationOptions + +Collects all configuration needed to activate a [`ReactiveMP.RandomVariable`](@ref). Passed to [`ReactiveMP.activate!(::RandomVariable, ::RandomVariableActivationOptions)`](@ref). + +Fields: +- `stream_postprocessor` — optional stream postprocessor applied to every created stream (see [`ReactiveMP.AbstractStreamPostprocessor`](@ref)) +- `prod_context_for_message_computation` — a [`ReactiveMP.MessageProductContext`](@ref) used when computing outbound messages (product of all-but-one inbound messages in the `EqualityChain`) +- `prod_context_for_marginal_computation` — a [`ReactiveMP.MessageProductContext`](@ref) used when computing the marginal (product of all inbound messages) +""" +struct RandomVariableActivationOptions{ + S, F <: MessageProductContext, M <: MessageProductContext +} + stream_postprocessor::S + prod_context_for_message_computation::F + prod_context_for_marginal_computation::M end RandomVariableActivationOptions() = RandomVariableActivationOptions( - AsapScheduler(), DefaultMessageProdFn, DefaultMarginalProdFn + nothing, MessageProductContext(), MessageProductContext() ) +""" + ReactiveMP.activate!(randomvar::RandomVariable, options::RandomVariableActivationOptions) + +Wires all reactive streams of a [`ReactiveMP.RandomVariable`](@ref) into the factor graph. + +Activation proceeds in two steps: + +1. **Outbound messages** — resizes `output_messages` to match the number of connected nodes (the [`ReactiveMP.degree`](@ref)). If degree > 1, an `EqualityChain` is constructed: for each edge i the outbound message stream emits the product of all inbound messages *except* the one arriving on edge i, implementing the standard sum-product or variational update. If degree == 1 (a leaf variable), the single outbound stream is connected to `never()` because there are no other messages to multiply. + +2. **Marginal** — `collectLatest` is called over all inbound [`ReactiveMP.MessageObservable`](@ref)s. It waits for all inbound messages to have emitted at least once, then emits the product as a new [`Marginal`](@ref) via [`ReactiveMP.set_stream_of_marginals!`](@ref), and re-emits only once all inbound messages have each updated again. + +See also: [`ReactiveMP.RandomVariableActivationOptions`](@ref), [`ReactiveMP.activate!(::DataVariable, ::DataVariableActivationOptions)`](@ref) +""" function activate!( randomvar::RandomVariable, options::RandomVariableActivationOptions ) @@ -76,8 +119,12 @@ function activate!( if length(randomvar.input_messages) > 1 chain = EqualityChain( randomvar.input_messages, - schedule_on(options.scheduler), - options.message_prod_fn, + options.stream_postprocessor, + (messages) -> compute_product_of_messages( + randomvar, + options.prod_context_for_message_computation, + messages, + ), ) initialize!(chain, outputmsgs) elseif length(randomvar.input_messages) == 1 @@ -92,25 +139,44 @@ function activate!( ) end - _setmarginal!(randomvar, _makemarginal(randomvar, options)) + stream_of_marginals = collectLatest( + AbstractMessage, + Marginal, + randomvar.input_messages, + (messages) -> + _compute_marginal_from_messages(randomvar, options, messages), + reset_vstatus, + ) + stream_of_marginals = postprocess_stream_of_marginals( + options.stream_postprocessor, stream_of_marginals + ) + + set_stream_of_marginals!(randomvar, stream_of_marginals) return nothing end -_getmarginal(randomvar::RandomVariable) = randomvar.marginal -_setmarginal!(randomvar::RandomVariable, observable) = connect!( - _getmarginal(randomvar), observable +function _compute_marginal_from_messages( + randomvar::RandomVariable, + options::RandomVariableActivationOptions, + messages, ) -_makemarginal( - randomvar::RandomVariable, options::RandomVariableActivationOptions -) = begin - return collectLatest( - AbstractMessage, - Marginal, - randomvar.input_messages, - options.marginal_prod_fn, - reset_vstatus, + context = options.prod_context_for_marginal_computation + span_id = generate_span_id(context.callbacks) + invoke_callback( + context.callbacks, + BeforeMarginalComputationEvent(randomvar, context, messages, span_id), + ) + result = as_marginal( + compute_product_of_messages(randomvar, context, messages) + ) + invoke_callback( + context.callbacks, + AfterMarginalComputationEvent( + randomvar, context, messages, result, span_id + ), ) + return result end # Reset consumption of the combination of inbound messages if the result of the computations is `is_initial` @@ -121,7 +187,7 @@ function reset_vstatus(wrapper, value) # The logic here is that if the result of the computation is `is_initial` we should reuse the arguments for the next computation # This may happen, when we initialize messages on the graph, which in turn also initializes marginals (implicitly) # if this happens, the inference cannot proceed further, since the initial messages have been consumed - # This also prevents weird FE behaviour, when it "maximizes" the FE value, but converges to a minimum value + # This also prevents weird FE behaviour, when it "maximizes" the FE value, but converges to a minimum value if is_initial(value) Rocket.fill_vstatus!(wrapper, true) end diff --git a/src/variables/variable.jl b/src/variables/variable.jl deleted file mode 100644 index 9eedb6cb6..000000000 --- a/src/variables/variable.jl +++ /dev/null @@ -1,114 +0,0 @@ -export AbstractVariable, degree -export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy -export getprediction, getpredictions, getmarginal, getmarginals -export setmarginal!, setmarginals!, setmessage!, setmessages! - -using Rocket - -abstract type AbstractVariable end - -## Base interface extensions - -Base.broadcastable(v::AbstractVariable) = Ref(v) - -## Messages to Marginal product strategies - -struct FoldLeftProdStrategy end -struct FoldRightProdStrategy end - -struct CustomProdStrategy{F} - prod_callback_generator::F -end - -""" - messages_prod_fn(strategy, prod_constraint, form_constraint, form_check_strategy) - -Returns a suitable prod computation function for a given strategy and constraints -""" -function messages_prod_fn end - -messages_prod_fn(::FoldLeftProdStrategy, prod_constraint, form_constraint, form_check_strategy) = prod_foldl_reduce(prod_constraint, form_constraint, form_check_strategy) -messages_prod_fn(::FoldRightProdStrategy, prod_constraint, form_constraint, form_check_strategy) = prod_foldr_reduce(prod_constraint, form_constraint, form_check_strategy) -messages_prod_fn(strategy::CustomProdStrategy, prod_constraint, form_constraint, form_check_strategy) = strategy.prod_callback_generator(prod_constraint, form_constraint, form_check_strategy) - -function marginal_prod_fn( - strategy, prod_constraint, form_constraint, form_check_strategy -) - return let prod_fn = messages_prod_fn( - strategy, prod_constraint, form_constraint, form_check_strategy - ) - return (messages) -> as_marginal(prod_fn(messages)) - end -end - -# Helper functions - -israndom(v::AbstractArray{<:AbstractVariable}) = all(israndom, v) -isdata(v::AbstractArray{<:AbstractVariable}) = all(isdata, v) -isconst(v::AbstractArray{<:AbstractVariable}) = all(isconst, v) - -# Getters - -getprediction(variable::AbstractVariable) = _getprediction(variable) -getpredictions(variables::AbstractArray{<:AbstractVariable}) = collectLatest(map(v -> getprediction(v), variables)) - -getmarginal(variable::AbstractVariable) = getmarginal(variable, SkipInitial()) -getmarginal(variable::AbstractVariable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(_getmarginal(variable), skip_strategy) - -getmarginals(variables::AbstractArray{<:AbstractVariable}) = getmarginals(variables, SkipInitial()) -getmarginals(variables::AbstractArray{<:AbstractVariable}, skip_strategy::MarginalSkipStrategy) = collectLatest(map(v -> getmarginal(v, skip_strategy), variables)) - -## Setters - -### Marginals - -setmarginal!(variable::AbstractVariable, marginal) = setmarginal!( - getmarginal(variable, IncludeAll()), marginal -) - -setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::PointMass) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) -setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::Distribution) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) -setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginals) = _setmarginals!(Base.IteratorSize(marginals), variables, marginals) - -function _setmarginals!( - ::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, marginals -) - @assert length(variables) == length(marginals) "Variables $(variables) and marginals $(marginals) should have the same length" - foreach(zip(variables, marginals)) do (variable, marginal) - setmarginal!(variable, marginal) - end -end - -function _setmarginals!( - ::Any, variables::AbstractArray{<:AbstractVariable}, marginals -) - error( - "setmarginals!() failed. Default value is neither an iterable object nor a distribution.", - ) -end - -### Messages - -setmessage!(variable::AbstractVariable, index::Int, message) = setmessage!(messageout(variable, index), message) -setmessage!(variable::AbstractVariable, message) = foreach(i -> setmessage!(variable, i, message), 1:degree(variable)) - -setmessages!(variables::AbstractArray{<:AbstractVariable}, message::PointMass) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) -setmessages!(variables::AbstractArray{<:AbstractVariable}, message::Distribution) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) -setmessages!(variables::AbstractArray{<:AbstractVariable}, messages) = _setmessages!(Base.IteratorSize(messages), variables, messages) - -function _setmessages!( - ::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, messages -) - @assert length(variables) == length(messages) "Variables $(variables) and messages $(messages) should have the same length" - foreach(zip(variables, messages)) do (variable, message) - setmessage!(variable, message) - end -end - -function _setmessages!( - ::Any, variables::AbstractArray{<:AbstractVariable}, marginals -) - error( - "setmessages!() failed. Default value is neither an iterable object nor a distribution.", - ) -end diff --git a/test/addons/debug_tests.jl b/test/addons/debug_tests.jl deleted file mode 100644 index 8033f9a20..000000000 --- a/test/addons/debug_tests.jl +++ /dev/null @@ -1,23 +0,0 @@ -@testitem "Debug addon" begin - using ExponentialFamily, BayesBase - import ReactiveMP: AddonDebug - - @testset "Creation" begin - addon = AddonDebug(x -> x == π) - - @test addon(π) - @test addon(3.14) == false - end - - @testset "Simple application and printing" begin - import ReactiveMP: multiply_addons - - addon = AddonDebug(x -> any(params(x) .== 3.14)) - @test_throws ErrorException multiply_addons( - addon, addon, NormalMeanVariance(0.0, 3.14), Missing(), Missing() - ) - @test multiply_addons( - addon, addon, NormalMeanVariance(0.0, 3.00), Missing(), Missing() - ) == addon - end -end diff --git a/test/addons/logscale_tests.jl b/test/addons/logscale_tests.jl deleted file mode 100644 index 9f76efd45..000000000 --- a/test/addons/logscale_tests.jl +++ /dev/null @@ -1,14 +0,0 @@ -@testitem "Logscale addon" begin - import ReactiveMP: AddonLogScale, AddonMemory, AddonDebug - - @testset "Error handling" begin - @test_throws "Log-scale addon is not available. Make sure to include AddonLogScale in the addons. Currently, log scale factors are only supported for very specific nodes and messages in sum-product updates. Extensions to variational message passing are not yet supported." getlogscale( - nothing - ) - end - - @testset "Simple application" begin - addon = AddonLogScale(0.0) - @test getlogscale(addon) == 0.0 - end -end diff --git a/test/addons/memory_tests.jl b/test/addons/memory_tests.jl deleted file mode 100644 index 51e00a6e6..000000000 --- a/test/addons/memory_tests.jl +++ /dev/null @@ -1,47 +0,0 @@ -@testitem "Memory addon" begin - using ExponentialFamily, BayesBase - - import ReactiveMP: AddonMemory - - @testset "Creation" begin - addon = AddonMemory() - - @test occursin("memory", string(addon)) - end - - @testset "Simple application and printing" begin - import ReactiveMP: MessageMapping, message_mapping_addon - - mapping = MessageMapping( - NormalMeanVariance, - Val(:out), - Marginalisation(), - Val((:x, :y)), - Val((:z, :k)), - "meta", - AddonMemory(), - nothing, - nothing, - ) - - messages = (Gamma(1.0, 1.0), NormalMeanVariance(0.0, 1.0)) - marginals = (PointMass(1.0), NormalMeanPrecision(0.0, 1.0)) - result = MvNormalMeanCovariance(ones(2), diageye(2)) - - addon = message_mapping_addon( - AddonMemory(), mapping, messages, marginals, result - ) - - displayed = repr(addon) - - @test occursin(r"node: ExponentialFamily.NormalMeanVariance", displayed) - @test occursin(r"interface: .*:out.*", displayed) - @test occursin(r"local constraint:.*Marginalisation()", displayed) - @test occursin(r"messages on.*(:x, :y).*edges", displayed) - @test occursin(repr(messages), displayed) - @test occursin(r"marginals on.*(:z, :k).*edges", displayed) - @test occursin(repr(marginals), displayed) - @test occursin(r"meta: meta", displayed) - @test occursin("result: $(repr(result))", displayed) - end -end diff --git a/test/addons_tests.jl b/test/addons_tests.jl deleted file mode 100644 index d2d2d8c08..000000000 --- a/test/addons_tests.jl +++ /dev/null @@ -1,57 +0,0 @@ - -@testitem "Addons" begin - using ReactiveMP, BayesBase, Distributions, ExponentialFamily - - using ReactiveMP: multiply_addons - - @testset "addonlogscale" begin - @testset "creation" begin - addon1 = AddonLogScale() - addon2 = AddonLogScale(2) - addon3 = AddonLogScale(3.0) - - @test addon1.logscale === nothing - @test addon2.logscale === 2 - @test addon3.logscale === 3.0 - end - - @testset "getlogscale" begin - message = Message(Normal(1, 0), false, false, (AddonLogScale(3),)) - marginal = Marginal(Normal(1, 0), false, false, (AddonLogScale(4.0),)) - - @test getlogscale(message) == 3 - @test getlogscale(marginal) == 4.0 - end - - @testset "multiply_addons" begin - left_addons = (AddonLogScale(5),) - right_addons = (AddonLogScale(6.0),) - new_dist = vague(Bernoulli) - left_dist = vague(Bernoulli) - right_dist = vague(Bernoulli) - - @test multiply_addons( - left_addons, right_addons, new_dist, left_dist, right_dist - ) == (AddonLogScale(11.0 - log(2)),) - @test multiply_addons( - AddonLogScale(5), - AddonLogScale(6.0), - new_dist, - left_dist, - right_dist, - ) == AddonLogScale(11.0 - log(2)) - @test multiply_addons( - AddonLogScale(5), nothing, new_dist, left_dist, missing - ) == AddonLogScale(5) - @test multiply_addons( - nothing, AddonLogScale(6.0), new_dist, missing, right_dist - ) == AddonLogScale(6.0) - @test multiply_addons( - nothing, nothing, new_dist, left_dist, missing - ) === nothing - @test multiply_addons( - nothing, nothing, new_dist, missing, right_dist - ) === nothing - end - end -end diff --git a/test/annotations/input_arguments_tests.jl b/test/annotations/input_arguments_tests.jl new file mode 100644 index 000000000..a2c4af471 --- /dev/null +++ b/test/annotations/input_arguments_tests.jl @@ -0,0 +1,370 @@ +@testmodule RuleInputArgumentsTestUtils begin + import ReactiveMP: + AnnotationDict, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord + + struct MockMapping + name::Symbol + end +end + +@testitem "post_rule_annotations! stores a RuleInputArgumentsRecord" setup=[ + RuleInputArgumentsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + post_rule_annotations!, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + get_rule_input_arguments + + ann = AnnotationDict() + mapping = RuleInputArgumentsTestUtils.MockMapping(:out) + messages = (:msg1, :msg2) + marginals = (:mar1,) + result = :the_result + + post_rule_annotations!( + InputArgumentsAnnotations(), ann, mapping, messages, marginals, result + ) + + record = get_rule_input_arguments(ann) + @test record isa RuleInputArgumentsRecord + @test record.mapping === mapping + @test record.messages === messages + @test record.marginals === marginals + @test record.result === result +end + +@testitem "post_product_annotations! merges two RuleInputArgumentsRecord into ProductInputArgumentsRecord" setup=[ + RuleInputArgumentsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + post_product_annotations!, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + get_rule_input_arguments + + left_record = RuleInputArgumentsRecord(RuleInputArgumentsTestUtils.MockMapping(:left), nothing, nothing, :left_result) + right_record = RuleInputArgumentsRecord(RuleInputArgumentsTestUtils.MockMapping(:right), nothing, nothing, :right_result) + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :rule_input_arguments, left_record) + annotate!(right_ann, :rule_input_arguments, right_record) + + merged = post_product_annotations!( + (InputArgumentsAnnotations(),), + left_ann, + right_ann, + nothing, + nothing, + nothing, + ) + + prod = get_rule_input_arguments(merged) + @test prod isa ProductInputArgumentsRecord + @test length(prod.mappings) == 2 + @test prod.mappings[1] === left_record + @test prod.mappings[2] === right_record +end + +@testitem "post_product_annotations! merges record (left) and prod (right)" setup=[ + RuleInputArgumentsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + post_product_annotations!, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + get_rule_input_arguments + + r1 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r1), nothing, nothing, :res1 + ) + r2 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r2), nothing, nothing, :res2 + ) + r3 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r3), nothing, nothing, :res3 + ) + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :rule_input_arguments, r1) + annotate!( + right_ann, :rule_input_arguments, ProductInputArgumentsRecord([r2, r3]) + ) + + merged = post_product_annotations!( + (InputArgumentsAnnotations(),), + left_ann, + right_ann, + nothing, + nothing, + nothing, + ) + + prod = get_rule_input_arguments(merged) + @test prod isa ProductInputArgumentsRecord + @test length(prod.mappings) == 3 + @test prod.mappings[1] === r1 + @test prod.mappings[2] === r2 + @test prod.mappings[3] === r3 +end + +@testitem "post_product_annotations! merges prod (left) and record (right)" setup=[ + RuleInputArgumentsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + post_product_annotations!, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + get_rule_input_arguments + + r1 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r1), nothing, nothing, :res1 + ) + r2 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r2), nothing, nothing, :res2 + ) + r3 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r3), nothing, nothing, :res3 + ) + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!( + left_ann, :rule_input_arguments, ProductInputArgumentsRecord([r1, r2]) + ) + annotate!(right_ann, :rule_input_arguments, r3) + + merged = post_product_annotations!( + (InputArgumentsAnnotations(),), + left_ann, + right_ann, + nothing, + nothing, + nothing, + ) + + prod = get_rule_input_arguments(merged) + @test prod isa ProductInputArgumentsRecord + @test length(prod.mappings) == 3 + @test prod.mappings[1] === r1 + @test prod.mappings[2] === r2 + @test prod.mappings[3] === r3 +end + +@testitem "post_product_annotations! merges two ProductInputArgumentsRecord" setup=[ + RuleInputArgumentsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + post_product_annotations!, + InputArgumentsAnnotations, + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + get_rule_input_arguments + + r1 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r1), nothing, nothing, :res1 + ) + r2 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r2), nothing, nothing, :res2 + ) + r3 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r3), nothing, nothing, :res3 + ) + r4 = RuleInputArgumentsRecord( + RuleInputArgumentsTestUtils.MockMapping(:r4), nothing, nothing, :res4 + ) + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!( + left_ann, :rule_input_arguments, ProductInputArgumentsRecord([r1, r2]) + ) + annotate!( + right_ann, :rule_input_arguments, ProductInputArgumentsRecord([r3, r4]) + ) + + merged = post_product_annotations!( + (InputArgumentsAnnotations(),), + left_ann, + right_ann, + nothing, + nothing, + nothing, + ) + + prod = get_rule_input_arguments(merged) + @test prod isa ProductInputArgumentsRecord + @test length(prod.mappings) == 4 + @test prod.mappings[1] === r1 + @test prod.mappings[2] === r2 + @test prod.mappings[3] === r3 + @test prod.mappings[4] === r4 +end + +@testitem "Base.show for RuleInputArgumentsRecord" begin + import ReactiveMP: RuleInputArgumentsRecord, MessageMapping, Marginalisation + import BayesBase: PointMass + + struct ShowRecordNode end + + mapping = MessageMapping( + ShowRecordNode, + Val(:out), + Marginalisation(), + Val((:in1, :in2)), + Val((:q1,)), + "some-meta", + nothing, + ShowRecordNode(), + nothing, + nothing, + ) + + record = RuleInputArgumentsRecord( + mapping, (PointMass(1.0), 2.0), (10.0,), 42.0 + ) + + output = sprint(show, record) + + @test occursin("Rule input arguments:", output) + @test occursin("node:", output) + @test occursin("ShowRecordNode", output) + @test occursin("interface:", output) + @test occursin(":out", output) + @test occursin("constraint:", output) + @test occursin("Marginalisation", output) + @test occursin("meta:", output) + @test occursin("some-meta", output) + @test occursin("msg(in1) = BayesBase.PointMass{Float64}(1.0)", output) + @test occursin("msg(in2) = 2.0", output) + @test occursin("q(q1) = 10.0", output) + @test occursin("result:", output) + @test occursin("42.0", output) +end + +@testitem "Base.show for RuleInputArgumentsRecord skips meta when nothing" begin + import ReactiveMP: RuleInputArgumentsRecord, MessageMapping, Marginalisation + + struct ShowRecordNoMetaNode end + + mapping = MessageMapping( + ShowRecordNoMetaNode, + Val(:out), + Marginalisation(), + Val((:in,)), + nothing, + nothing, + nothing, + ShowRecordNoMetaNode(), + nothing, + nothing, + ) + + record = RuleInputArgumentsRecord(mapping, (1.0,), nothing, 2.0) + output = sprint(show, record) + + @test !occursin("meta:", output) + @test occursin("msg(in) = 1.0", output) +end + +@testitem "Base.show for RuleInputArgumentsRecord skips messages/marginals when nothing" begin + import ReactiveMP: RuleInputArgumentsRecord, MessageMapping, Marginalisation + + struct ShowRecordEmptyInputsNode end + + mapping = MessageMapping( + ShowRecordEmptyInputsNode, + Val(:out), + Marginalisation(), + nothing, + nothing, + nothing, + nothing, + ShowRecordEmptyInputsNode(), + nothing, + nothing, + ) + + record = RuleInputArgumentsRecord(mapping, nothing, nothing, :the_result) + output = sprint(show, record) + + @test !occursin("msg(", output) + @test !occursin("q(", output) + @test occursin("result:", output) + @test occursin("the_result", output) +end + +@testitem "Base.show for ProductInputArgumentsRecord" begin + import ReactiveMP: + RuleInputArgumentsRecord, + ProductInputArgumentsRecord, + MessageMapping, + Marginalisation + + struct ShowProductNodeA end + struct ShowProductNodeB end + + mapping_a = MessageMapping( + ShowProductNodeA, + Val(:out), + Marginalisation(), + Val((:in,)), + nothing, + nothing, + nothing, + ShowProductNodeA(), + nothing, + nothing, + ) + + mapping_b = MessageMapping( + ShowProductNodeB, + Val(:mean), + Marginalisation(), + Val((:x,)), + nothing, + nothing, + nothing, + ShowProductNodeB(), + nothing, + nothing, + ) + + r1 = RuleInputArgumentsRecord(mapping_a, (1.0,), nothing, :res_a) + r2 = RuleInputArgumentsRecord(mapping_b, (2.0,), nothing, :res_b) + prod = ProductInputArgumentsRecord([r1, r2]) + + output = sprint(show, prod) + + @test occursin("Product of 2 rule input arguments:", output) + @test occursin("[1]", output) + @test occursin("[2]", output) + @test occursin("ShowProductNodeA", output) + @test occursin("ShowProductNodeB", output) + @test occursin("res_a", output) + @test occursin("res_b", output) +end + +@testitem "AddonMemory throws an error" begin + import ReactiveMP: AddonMemory + + @test_throws "AddonMemory` has been removed" AddonMemory() + @test_throws "InputArgumentsAnnotations" AddonMemory() +end diff --git a/test/annotations/logscale_tests.jl b/test/annotations/logscale_tests.jl new file mode 100644 index 000000000..2e01e44ad --- /dev/null +++ b/test/annotations/logscale_tests.jl @@ -0,0 +1,145 @@ +@testmodule LogScaleAnnotationsTestUtils begin + import ReactiveMP: AnnotationDict, LogScaleAnnotations + import BayesBase: compute_logscale, PointMass + import BayesBase + + struct CustomDistributionForLogScaleTesting end + + BayesBase.compute_logscale( + ::CustomDistributionForLogScaleTesting, + ::CustomDistributionForLogScaleTesting, + ::CustomDistributionForLogScaleTesting, + ) = 10.0 +end + +@testitem "getlogscale reads from AnnotationDict" begin + import ReactiveMP: AnnotationDict, annotate!, getlogscale + + ann = AnnotationDict() + annotate!(ann, :logscale, 3.0) + + @test getlogscale(ann) == 3.0 +end + +@testitem "getlogscale throws when logscale is not set" begin + import ReactiveMP: AnnotationDict, getlogscale + + ann = AnnotationDict() + + @test_throws KeyError getlogscale(ann) +end + +@testitem "@logscale macro sets logscale annotation via getannotations" begin + import ReactiveMP: AnnotationDict, getlogscale, @logscale + + _annotations = AnnotationDict() + getannotations = () -> _annotations + @logscale 2.5 + + @test getlogscale(_annotations) == 2.5 +end + +@testitem "post_rule_annotations! is no-op when logscale already annotated" setup = [ + LogScaleAnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + getlogscale, + post_rule_annotations!, + LogScaleAnnotations + + ann = AnnotationDict() + annotate!(ann, :logscale, 7.0) + + post_rule_annotations!( + LogScaleAnnotations(), ann, nothing, nothing, nothing, nothing + ) + + @test getlogscale(ann) == 7.0 +end + +@testitem "post_rule_annotations! sets logscale to 0 when all messages are PointMass" setup = [ + LogScaleAnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + getlogscale, + post_rule_annotations!, + LogScaleAnnotations, + Message + import BayesBase: PointMass + + ann = AnnotationDict() + messages = (Message(PointMass(1.0), false, false), Message(PointMass(2.0), false, false)) + + post_rule_annotations!( + LogScaleAnnotations(), ann, nothing, messages, nothing, nothing + ) + + @test getlogscale(ann) == 0 +end + +@testitem "post_rule_annotations! sets logscale to 0 when all marginals are PointMass" setup = [ + LogScaleAnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + getlogscale, + post_rule_annotations!, + LogScaleAnnotations, + Marginal + import BayesBase: PointMass + + ann = AnnotationDict() + marginals = (Marginal(PointMass(1.0), false, false),) + + post_rule_annotations!( + LogScaleAnnotations(), ann, nothing, nothing, marginals, nothing + ) + + @test getlogscale(ann) == 0 +end + +@testitem "post_rule_annotations! errors when logscale not set and inputs are not all PointMass" setup = [ + LogScaleAnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, post_rule_annotations!, LogScaleAnnotations + + ann = AnnotationDict() + messages = (Message(LogScaleAnnotationsTestUtils.CustomDistributionForLogScaleTesting(), false, false),) + + @test_throws "Log-scale annotation has not been set" post_rule_annotations!( + LogScaleAnnotations(), ann, nothing, messages, nothing, nothing + ) +end + +@testitem "post_product_annotations! with LogScaleAnnotations sums logscales and adds compute_logscale" setup = [ + LogScaleAnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + getlogscale, + post_product_annotations!, + LogScaleAnnotations + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :logscale, 1.0) + annotate!(right_ann, :logscale, 2.0) + + dist = LogScaleAnnotationsTestUtils.CustomDistributionForLogScaleTesting() + merged = post_product_annotations!((LogScaleAnnotations(),), left_ann, right_ann, dist, dist, dist) + + # 1.0 + 2.0 + compute_logscale(...) = 1.0 + 2.0 + 10.0 = 13.0 + @test getlogscale(merged) == 13.0 +end + +@testitem "AddonLogScale throws an error" begin + import ReactiveMP: AddonLogScale + + @test_throws "AddonLogScale` has been removed" AddonLogScale() + @test_throws "LogScaleAnnotations" AddonLogScale() +end diff --git a/test/annotations_tests.jl b/test/annotations_tests.jl new file mode 100644 index 000000000..6667a6a65 --- /dev/null +++ b/test/annotations_tests.jl @@ -0,0 +1,214 @@ +@testmodule AnnotationsTestUtils begin + import ReactiveMP: + AbstractAnnotations, + AnnotationDict, + annotate!, + get_annotation, + post_product_annotations! + import ReactiveMP + + struct Normal + mean::Float64 + std::Float64 + end + + struct SumAnnotations <: AbstractAnnotations end + + function ReactiveMP.post_product_annotations!( + ::SumAnnotations, + merged, + left_ann, + right_ann, + new_dist, + left_dist, + right_dist, + ) + annotate!( + merged, + :sum, + get_annotation(left_ann, :val) + get_annotation(right_ann, :val), + ) + end +end + +@testitem "AnnotationDict can be created" begin + import ReactiveMP: AnnotationDict, annotate!, get_annotation, has_annotation + + ann = AnnotationDict() + + @test !has_annotation(ann, :logscale) + + annotate!(ann, :logscale, 1.0) + + @test has_annotation(ann, :logscale) + @test get_annotation(ann, :logscale) == 1.0 + @test @inferred(get_annotation(ann, Float64, :logscale)) == 1.0 +end + +@testitem "AnnotationDict can be copied with copy constructor" begin + import ReactiveMP: AnnotationDict, annotate!, get_annotation, has_annotation + + original = AnnotationDict() + annotate!(original, :foo, 1) + annotate!(original, :bar, 2) + + copied = AnnotationDict(original) + + @test has_annotation(copied, :foo) + @test has_annotation(copied, :bar) + @test get_annotation(copied, :foo) == 1 + @test get_annotation(copied, :bar) == 2 + + # mutating the copy does not affect the original + annotate!(copied, :foo, 99) + @test get_annotation(original, :foo) == 1 +end + +@testitem "AnnotationDict isempty" begin + import ReactiveMP: AnnotationDict, annotate! + + ann = AnnotationDict() + @test isempty(ann) + + annotate!(ann, :foo, 1) + @test !isempty(ann) +end + +@testitem "AnnotationDict show" begin + import ReactiveMP: AnnotationDict, annotate! + + ann = AnnotationDict() + @test repr(ann) == "AnnotationDict()" + + annotate!(ann, :logscale, 1.0) + @test occursin("logscale", repr(ann)) + @test occursin("1.0", repr(ann)) +end + +@testitem "AnnotationDict does not allocate on simple creation" begin + import ReactiveMP: AnnotationDict, has_annotation + + function foo() + ann = AnnotationDict() + return has_annotation(ann, :logscale) + end + + foo() + + @test @allocated(foo()) === 0 +end + +@testitem "post_product_annotations! with no processors returns empty AnnotationDict" setup=[ + AnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, annotate!, has_annotation, post_product_annotations! + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :foo, 1) + annotate!(right_ann, :foo, 2) + + dist = AnnotationsTestUtils.Normal(0.0, 1.0) + + for processors in (nothing, ()) + result = post_product_annotations!( + processors, left_ann, right_ann, dist, dist, dist + ) + @test result isa AnnotationDict + @test !has_annotation(result, :foo) + end +end + +@testitem "post_product_annotations! calls per-processor post_product_annotations! for each processor" setup=[ + AnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + get_annotation, + has_annotation, + post_product_annotations! + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :val, 3) + annotate!(right_ann, :val, 7) + + dist = AnnotationsTestUtils.Normal(0.0, 1.0) + + result = post_product_annotations!( + (AnnotationsTestUtils.SumAnnotations(),), + left_ann, + right_ann, + dist, + dist, + dist, + ) + @test has_annotation(result, :sum) + @test get_annotation(result, :sum) == 10 +end + +@testitem "post_product_annotations! with missing left_dist copies right_ann" setup=[ + AnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + get_annotation, + has_annotation, + post_product_annotations! + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(right_ann, :logscale, 5.0) + + dist = AnnotationsTestUtils.Normal(0.0, 1.0) + + result = post_product_annotations!( + nothing, left_ann, right_ann, dist, missing, dist + ) + @test has_annotation(result, :logscale) + @test get_annotation(result, :logscale) == 5.0 +end + +@testitem "post_product_annotations! with missing right_dist copies left_ann" setup=[ + AnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, + annotate!, + get_annotation, + has_annotation, + post_product_annotations! + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :logscale, 3.0) + + dist = AnnotationsTestUtils.Normal(0.0, 1.0) + + result = post_product_annotations!( + nothing, left_ann, right_ann, dist, dist, missing + ) + @test has_annotation(result, :logscale) + @test get_annotation(result, :logscale) == 3.0 +end + +@testitem "post_product_annotations! with both dists missing returns empty AnnotationDict" setup=[ + AnnotationsTestUtils +] begin + import ReactiveMP: + AnnotationDict, annotate!, has_annotation, post_product_annotations! + + left_ann = AnnotationDict() + right_ann = AnnotationDict() + annotate!(left_ann, :logscale, 1.0) + annotate!(right_ann, :logscale, 2.0) + + result = post_product_annotations!( + nothing, left_ann, right_ann, missing, missing, missing + ) + @test result isa AnnotationDict + @test !has_annotation(result, :logscale) +end diff --git a/test/callbacks_tests.jl b/test/callbacks_tests.jl new file mode 100644 index 000000000..afdbf1199 --- /dev/null +++ b/test/callbacks_tests.jl @@ -0,0 +1,446 @@ +@testmodule CallbacksTestUtils begin + import ReactiveMP: Event + + struct CustomEvent{E, D} <: Event{E} + data::D + end + + function CustomEvent(E::Symbol, args...) + return CustomEvent{E, typeof(args)}(args) + end + + mutable struct MutableCustomEvent{E, D} <: Event{E} + data::D + state::Any + end + + function MutableCustomEvent(E::Symbol, args...; state = nothing) + return MutableCustomEvent{E, typeof(args)}(args, state) + end + + export CustomEvent, MutableCustomEvent +end + +@testitem "Callbacks handler should do absolutely nothing if no handler exists" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, Event, generate_span_id + using UUIDs + + # We use here a type stable structure to achieve 0 allocations + struct MyCustomEvent{T} <: Event{:my_custom_event} + a::Int + b::T + c::Vector{Int} + d::Matrix{Int} + end + + event = MyCustomEvent(1, "Hello", [1, 2, 3], [1;;]) + callback_handler = nothing + + function bar(callback_handler, event) + invoke_callback(callback_handler, event) + return nothing + end + + bar(callback_handler, event) + + @test @inferred(bar(callback_handler, event)) === nothing + @test @allocated(bar(callback_handler, event)) === 0 + + if VERSION >= v"1.12.0" + function bar2(callback_handler) + invoke_callback( + callback_handler, CustomEvent(:event1, 1, 2, "asd", [2]) + ) + return 1 + 2 + end + + bar2(callback_handler) + + @test @inferred(bar2(callback_handler)) === 3 + @test @allocated(bar2(callback_handler)) === 0 + end + + mutable struct EventWithTypeStableState <: + Event{:event_with_typestable_state} + internal_state::Bool + end + + function bar3(callback_handler) + event = EventWithTypeStableState(true) + invoke_callback(callback_handler, event) + return event.internal_state + end + + bar3(callback_handler) + + @test @inferred(bar3(callback_handler)) === true + @test @allocated(bar3(callback_handler)) === 0 + + mutable struct EventWithTypeUnstableState <: + Event{:event_with_typeunstable_state} + state + end + + function bar4(callback_handler) + event = EventWithTypeUnstableState(nothing) + invoke_callback(callback_handler, event) + if event.state === nothing + return 1 + end + return [1] + end + + bar4(callback_handler) + + @test @allocated(bar4(callback_handler)) === 0 + + if VERSION >= v"1.12.0" + # Test that span_id does not cause allocations + struct BeforeSuperCoolEvent{I} <: Event{:my_custom_event_243} + span_id::I + end + struct AfterSuperCoolEvent{I} <: Event{:my_custom_event_534} + result::Float64 + span_id::I + end + + function bar5(callback_handler, input::Float64) + span_id = generate_span_id(callback_handler) + + invoke_callback(callback_handler, BeforeSuperCoolEvent(span_id)) + + result = input + 4.0 + + invoke_callback( + callback_handler, AfterSuperCoolEvent(result, span_id) + ) + + return result + end + + @test bar5(callback_handler, 9.0) == 13.0 + @test @allocated(bar5(callback_handler, 9.0)) === 0 + end +end + +@testitem "invoke_callback should return the event" setup = [CallbacksTestUtils] begin + import ReactiveMP: invoke_callback + + event = CustomEvent(:event1, 1, 2) + + # nothing handler returns the event + @test invoke_callback(nothing, event) === event + + # NamedTuple handler returns the event + callbacks = (event1 = (e) -> nothing,) + @test invoke_callback(callbacks, event) === event + + # Dict handler returns the event + dict_callbacks = Dict{Symbol, Any}(:event1 => (e) -> nothing) + @test invoke_callback(dict_callbacks, event) === event + + # Unmatched event still returns the event + @test invoke_callback(callbacks, CustomEvent(:other, 1)) === + CustomEvent(:other, 1) +end + +@testitem "event_name should work on both types and instances" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: event_name, Event + + struct MyEvent <: Event{:my_event} + value::Int + end + + # Works on types + @test event_name(MyEvent) === :my_event + @test event_name(CustomEvent{:foo, Tuple{Int}}) === :foo + + # Works on instances + @test event_name(MyEvent(42)) === :my_event + @test event_name(CustomEvent(:bar, 1, 2)) === :bar + + # Works on built-in event types + @test event_name(ReactiveMP.BeforeMessageRuleCallEvent) === + :before_message_rule_call + @test event_name(ReactiveMP.AfterProductOfTwoMessagesEvent) === + :after_product_of_two_messages +end + +@testitem "It should be possible to define custom callback handlers via handle_event" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, handle_event, Event + + struct MyCallbackHandler + events + end + + function ReactiveMP.handle_event( + handler::MyCallbackHandler, event::Event{E} + ) where {E} + push!(handler.events, (event = E, data = event.data)) + return nothing + end + + handler = MyCallbackHandler([]) + + @test invoke_callback(handler, CustomEvent(:event1, 1, 1)) isa + CustomEvent{:event1} + @test invoke_callback(handler, CustomEvent(:event2, 2, 3)) isa + CustomEvent{:event2} + + @test length(handler.events) === 2 + @test handler.events[1].event === :event1 + @test handler.events[1].data === (1, 1) + @test handler.events[2].event === :event2 + @test handler.events[2].data === (2, 3) + + @test_throws MethodError invoke_callback( + handler, "unsupported type of event" + ) +end + +@testitem "Custom callback handler can mutate event state" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, handle_event, Event + + struct StateModifyingHandler end + + function ReactiveMP.handle_event( + ::StateModifyingHandler, event::MutableCustomEvent{:my_event} + ) + event.state = :modified + return nothing + end + + handler = StateModifyingHandler() + event = MutableCustomEvent(:my_event, 1, 2; state = nothing) + + @test event.state === nothing + returned_event = invoke_callback(handler, event) + @test returned_event === event + @test event.state === :modified +end + +@testitem "NamedTuple callback can mutate event state" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback + + callbacks = (my_event = (event) -> begin + event.state = sum(event.data) + end,) + + event = MutableCustomEvent(:my_event, 3, 4; state = nothing) + returned_event = invoke_callback(callbacks, event) + + @test returned_event === event + @test event.state === 7 +end + +@testitem "Dict callback can mutate event state" setup = [CallbacksTestUtils] begin + import ReactiveMP: invoke_callback + + callbacks = Dict{Symbol, Any}( + :my_event => (event) -> begin + event.state = prod(event.data) + end + ) + + event = MutableCustomEvent(:my_event, 3, 5; state = nothing) + returned_event = invoke_callback(callbacks, event) + + @test returned_event === event + @test event.state === 15 +end + +@testitem "invoke_callback error hint for forgotten trailing comma in NamedTuple" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback + + # This simulates the common mistake: `(before_product_of_messages = fn)` without trailing comma. + # Julia parses this as a plain assignment, so `callbacks` becomes just `fn` (a Function). + callbacks = (before_product_of_messages = (event) -> nothing) + + # Verify that Julia indeed parsed this as a Function, not a NamedTuple + @test callbacks isa Function + @test !(callbacks isa NamedTuple) + + err = try + invoke_callback(callbacks, CustomEvent(:before_product_of_messages)) + catch e + e + end + @test err isa MethodError + + # Check that the error hint mentions both possible causes + hint_message = sprint(showerror, err) + @test occursin("handle_event", hint_message) + @test occursin("trailing comma", hint_message) +end + +@testitem "invoke_callback error hint for custom handler with missing method" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback + + # Custom handler that only implements handle_event for :event1 but not :event2 + struct IncompleteHandler end + + ReactiveMP.handle_event(::IncompleteHandler, ::CustomEvent{:event1}) = + nothing + + handler = IncompleteHandler() + + # :event1 works fine + @test invoke_callback(handler, CustomEvent(:event1)) isa + CustomEvent{:event1} + + # :event2 is not implemented — should hit MethodError with a helpful hint + err = try + invoke_callback(handler, CustomEvent(:event2)) + catch e + e + end + @test err isa MethodError + + hint_message = sprint(showerror, err) + @test occursin( + r"ReactiveMP\.handle_event\(::.*IncompleteHandler, event::Event\{:event2\}\) = \.\.\.", + hint_message, + ) + @test occursin( + "You meant to pass a `NamedTuple` as the callbacks handler but forgot the trailing comma.", + hint_message, + ) +end + +@testitem "NamedTuple should be a supported event handler" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback + + callback_handler = ( + sum_event = (event) -> nothing, prod_event = (event) -> nothing + ) + + @test invoke_callback(callback_handler, CustomEvent(:sum_event, 1, 2)) isa + CustomEvent{:sum_event} + @test invoke_callback( + callback_handler, CustomEvent(:prod_event, 1, 2, 3) + ) isa CustomEvent{:prod_event} + @test invoke_callback( + callback_handler, CustomEvent(:other_event, 1, 2, 3) + ) isa CustomEvent{:other_event} +end + +@testitem "Dict{Symbol} should be a supported event handler" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback + + callback_handler = Dict{Symbol, Any}( + :sum_event => (event) -> nothing, :prod_event => (event) -> nothing + ) + + @test invoke_callback(callback_handler, CustomEvent(:sum_event, 1, 2)) isa + CustomEvent{:sum_event} + @test invoke_callback(callback_handler, CustomEvent(:prod_event, 1, 2)) isa + CustomEvent{:prod_event} + @test invoke_callback( + callback_handler, CustomEvent(:other_event, 1, 2, 3) + ) === CustomEvent(:other_event, 1, 2, 3) +end + +@testitem "It should be possible to merge callback handlers" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, merge_callbacks, handle_event, Event + + # listens to event 1 and event 2 + handler1_events = [] + callback_handler1 = ( + event1 = (event) -> push!(handler1_events, :event1), + event2 = (event) -> push!(handler1_events, :event2), + ) + + # listens to event3 and event 2 + handler2_events = [] + callback_handler2 = ( + event3 = (event) -> push!(handler2_events, :event3), + event2 = (event) -> push!(handler2_events, :event2), + ) + + # only listens to event 2 + struct MyCustomHandler + events + end + + ReactiveMP.handle_event(::MyCustomHandler, ::Event) = nothing + ReactiveMP.handle_event(handler::MyCustomHandler, ::CustomEvent{:event2}) = push!( + handler.events, :event2 + ) + + custom_handler = MyCustomHandler([]) + + merged_handler = merge_callbacks( + callback_handler1, callback_handler2, custom_handler + ) + + for i in 1:5 + invoke_callback(merged_handler, CustomEvent(:event1, 1, 1)) + invoke_callback(merged_handler, CustomEvent(:event2, "hello")) + invoke_callback(merged_handler, CustomEvent(:event3, 3.0)) + end + + @test length(handler1_events) == 10 + @test Set(handler1_events) == Set([:event1, :event2]) + @test length(handler2_events) == 10 + @test Set(handler2_events) == Set([:event3, :event2]) + @test length(custom_handler.events) == 5 + @test Set(custom_handler.events) == Set([:event2]) +end + +@testitem "Merged callbacks should return the event" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, merge_callbacks + + callback_handler1 = (event1 = (event) -> nothing,) + callback_handler2 = (event1 = (event) -> nothing,) + + merged_handler = merge_callbacks(callback_handler1, callback_handler2) + + event = CustomEvent(:event1, 2, 3) + @test invoke_callback(merged_handler, event) === event +end + +@testitem "Merged callbacks can mutate event state across handlers" setup = [ + CallbacksTestUtils +] begin + import ReactiveMP: invoke_callback, merge_callbacks + + # First handler sets state to 1 + handler1 = (my_event = (event) -> begin + event.state = 1 + end,) + + # Second handler increments state + handler2 = (my_event = (event) -> begin + event.state += 10 + end,) + + merged_handler = merge_callbacks(handler1, handler2) + + event = MutableCustomEvent(:my_event, 1, 2; state = nothing) + returned_event = invoke_callback(merged_handler, event) + + @test returned_event === event + @test event.state === 11 +end diff --git a/test/marginal_tests.jl b/test/marginal_tests.jl index 02ad877e1..07fa58cf5 100644 --- a/test/marginal_tests.jl +++ b/test/marginal_tests.jl @@ -5,23 +5,22 @@ import Base: methods import Base.Iterators: repeated, product import BayesBase: xtlog, mirrorlog - import ReactiveMP: getaddons, as_marginal + import ReactiveMP: getannotations, AnnotationDict, as_marginal import SpecialFunctions: loggamma @testset "Default methods" begin for clamped in (true, false), - initial in (true, false), addons in (1, 2), + initial in (true, false), data in (1, 1.0, Normal(0, 1), Gamma(1, 1), PointMass(1)) - marginal = Marginal(data, clamped, initial, addons) + marginal = Marginal(data, clamped, initial) @test getdata(marginal) === data @test is_clamped(marginal) === clamped @test is_initial(marginal) === initial @test as_marginal(marginal) === marginal - @test getaddons(marginal) === addons + @test getannotations(marginal) isa AnnotationDict @test occursin("Marginal", repr(marginal)) @test occursin(repr(data), repr(marginal)) - @test occursin(repr(addons), repr(marginal)) end dist1 = NormalMeanVariance(0.0, 1.0) @@ -31,8 +30,8 @@ clamped2 in (true, false), initial1 in (true, false), initial2 in (true, false) - msg1 = Marginal(dist1, clamped1, initial1, nothing) - msg2 = Marginal(dist2, clamped2, initial2, nothing) + msg1 = Marginal(dist1, clamped1, initial1) + msg2 = Marginal(dist2, clamped2, initial2) @test getdata((msg1, msg2)) === (dist1, dist2) @test is_clamped((msg1, msg2)) === all([clamped1, clamped2]) @@ -92,7 +91,7 @@ method in methods_to_test T = typeof(distribution) - marginal = Marginal(distribution, false, false, nothing) + marginal = Marginal(distribution, false, false) # Here we check that a specialised method for a particular type T exist ms = methods(method, (T,)) if !isempty(ms) && all(m -> m ∈ distribution_methods, ms) @@ -105,7 +104,7 @@ for distribution in distributions, fn_mean in fn_mean_functions F = typeof(fn_mean) T = typeof(distribution) - marginal = Marginal(distribution, false, false, nothing) + marginal = Marginal(distribution, false, false) # Here we check that a specialised method for a particular type T exist ms = methods(mean, (F, T), ReactiveMP) if !isempty(ms) @@ -133,7 +132,7 @@ rng = MersenneTwister(1234) for distribution in distributions2, method in methods_to_test2 - marginal = Marginal(distribution, false, false, nothing) + marginal = Marginal(distribution, false, false) for _ in 1:3 point = _getpoint(rng, distribution) diff --git a/test/message_tests.jl b/test/message_tests.jl index 81537d3cb..6cfcb36ee 100644 --- a/test/message_tests.jl +++ b/test/message_tests.jl @@ -5,23 +5,27 @@ import Base: methods import Base.Iterators: repeated, product import BayesBase: xtlog, mirrorlog - import ReactiveMP: getaddons, multiply_messages, as_message + import ReactiveMP: + getannotations, + AnnotationDict, + compute_product_of_two_messages, + MessageProductContext, + as_message import SpecialFunctions: loggamma @testset "Default methods" begin for clamped in (true, false), - initial in (true, false), addons in (1, 2), + initial in (true, false), data in (1, 1.0, Normal(0, 1), Gamma(1, 1), PointMass(1)) - msg = Message(data, clamped, initial, addons) + msg = Message(data, clamped, initial) @test getdata(msg) === data @test is_clamped(msg) === clamped @test is_initial(msg) === initial @test as_message(msg) === msg - @test getaddons(msg) === addons + @test getannotations(msg) isa AnnotationDict @test occursin("Message", repr(msg)) @test occursin(repr(data), repr(msg)) - @test occursin(repr(addons), repr(msg)) end dist1 = NormalMeanVariance(0.0, 1.0) @@ -31,8 +35,8 @@ clamped2 in (true, false), initial1 in (true, false), initial2 in (true, false) - msg1 = Message(dist1, clamped1, initial1, nothing) - msg2 = Message(dist2, clamped2, initial2, nothing) + msg1 = Message(dist1, clamped1, initial1) + msg2 = Message(dist2, clamped2, initial2) @test getdata((msg1, msg2)) === (dist1, dist2) @test is_clamped((msg1, msg2)) === all([clamped1, clamped2]) @@ -40,108 +44,100 @@ end end - @testset "multiply_messages" begin - × = (x, y) -> multiply_messages(GenericProd(), x, y) + @testset "compute product of two messages" begin + _testvar = ReactiveMP.randomvar() + × = + (x, y) -> compute_product_of_two_messages( + _testvar, MessageProductContext(), x, y + ) dist1 = NormalMeanVariance(randn(), rand()) dist2 = NormalMeanVariance(randn(), rand()) @test getdata( - Message(dist1, false, false, nothing) × - Message(dist2, false, false, nothing), + Message(dist1, false, false) × Message(dist2, false, false) ) == prod(GenericProd(), dist1, dist2) @test getdata( - Message(dist2, false, false, nothing) × - Message(dist1, false, false, nothing), + Message(dist2, false, false) × Message(dist1, false, false) ) == prod(GenericProd(), dist2, dist1) for (left_is_initial, right_is_initial) in product(repeated([true, false], 2)...) @test is_clamped( - Message(dist1, true, left_is_initial, nothing) × - Message(dist2, false, right_is_initial, nothing), + Message(dist1, true, left_is_initial) × + Message(dist2, false, right_is_initial), ) == false @test is_clamped( - Message(dist1, false, left_is_initial, nothing) × - Message(dist2, true, right_is_initial, nothing), + Message(dist1, false, left_is_initial) × + Message(dist2, true, right_is_initial), ) == false @test is_clamped( - Message(dist1, true, left_is_initial, nothing) × - Message(dist2, true, right_is_initial, nothing), + Message(dist1, true, left_is_initial) × + Message(dist2, true, right_is_initial), ) == true @test is_clamped( - Message(dist2, true, left_is_initial, nothing) × - Message(dist1, false, right_is_initial, nothing), + Message(dist2, true, left_is_initial) × + Message(dist1, false, right_is_initial), ) == false @test is_clamped( - Message(dist2, false, left_is_initial, nothing) × - Message(dist1, true, right_is_initial, nothing), + Message(dist2, false, left_is_initial) × + Message(dist1, true, right_is_initial), ) == false @test is_clamped( - Message(dist2, true, left_is_initial, nothing) × - Message(dist1, true, right_is_initial, nothing), + Message(dist2, true, left_is_initial) × + Message(dist1, true, right_is_initial), ) == true end for (left_is_clamped, right_is_clamped) in product(repeated([true, false], 2)...) @test is_initial( - Message(dist1, left_is_clamped, true, nothing) × - Message(dist2, right_is_clamped, true, nothing), + Message(dist1, left_is_clamped, true) × + Message(dist2, right_is_clamped, true), ) == !(left_is_clamped && right_is_clamped) @test is_initial( - Message(dist2, left_is_clamped, true, nothing) × - Message(dist1, right_is_clamped, true, nothing), + Message(dist2, left_is_clamped, true) × + Message(dist1, right_is_clamped, true), ) == !(left_is_clamped && right_is_clamped) @test is_initial( - Message(dist1, left_is_clamped, false, nothing) × - Message(dist2, right_is_clamped, false, nothing), + Message(dist1, left_is_clamped, false) × + Message(dist2, right_is_clamped, false), ) == false @test is_initial( - Message(dist2, left_is_clamped, false, nothing) × - Message(dist1, right_is_clamped, false, nothing), + Message(dist2, left_is_clamped, false) × + Message(dist1, right_is_clamped, false), ) == false end @test is_initial( - Message(dist1, true, true, nothing) × - Message(dist2, true, true, nothing), + Message(dist1, true, true) × Message(dist2, true, true) ) == false @test is_initial( - Message(dist1, true, true, nothing) × - Message(dist2, true, false, nothing), + Message(dist1, true, true) × Message(dist2, true, false) ) == false @test is_initial( - Message(dist1, true, false, nothing) × - Message(dist2, true, true, nothing), + Message(dist1, true, false) × Message(dist2, true, true) ) == false @test is_initial( - Message(dist1, false, true, nothing) × - Message(dist2, true, false, nothing), + Message(dist1, false, true) × Message(dist2, true, false) ) == true @test is_initial( - Message(dist1, true, false, nothing) × - Message(dist2, false, true, nothing), + Message(dist1, true, false) × Message(dist2, false, true) ) == true @test is_initial( - Message(dist2, true, true, nothing) × - Message(dist1, true, true, nothing), + Message(dist2, true, true) × Message(dist1, true, true) ) == false @test is_initial( - Message(dist2, true, true, nothing) × - Message(dist1, true, false, nothing), + Message(dist2, true, true) × Message(dist1, true, false) ) == false @test is_initial( - Message(dist2, true, false, nothing) × - Message(dist1, true, true, nothing), + Message(dist2, true, false) × Message(dist1, true, true) ) == false @test is_initial( - Message(dist2, false, true, nothing) × - Message(dist1, true, false, nothing), + Message(dist2, false, true) × Message(dist1, true, false) ) == true @test is_initial( - Message(dist2, true, false, nothing) × - Message(dist1, false, true, nothing), + Message(dist2, true, false) × Message(dist1, false, true) ) == true end @@ -197,7 +193,7 @@ method in methods_to_test T = typeof(distribution) - message = Message(distribution, false, false, nothing) + message = Message(distribution, false, false) # Here we check that a specialised method for a particular type T exist ms = methods(method, (T,)) if !isempty(ms) && all(m -> m ∈ distribution_methods, ms) @@ -210,7 +206,7 @@ for distribution in distributions, fn_mean in fn_mean_functions F = typeof(fn_mean) T = typeof(distribution) - message = Message(distribution, false, false, nothing) + message = Message(distribution, false, false) # Here we check that a specialised method for a particular type T exist ms = methods(mean, (F, T), ReactiveMP) if !isempty(ms) @@ -238,7 +234,7 @@ rng = MersenneTwister(1234) for distribution in distributions2, method in methods_to_test2 - message = Message(distribution, false, false, nothing) + message = Message(distribution, false, false) for _ in 1:3 point = _getpoint(rng, distribution) @@ -259,7 +255,7 @@ end dmessage = DeferredMessage( messages_stream, marginals_stream, - (a, b) -> Message(a + b, false, false, nothing), + (a, b) -> Message(a + b, false, false), ) # The data cannot be computed since no values were provided yet @@ -286,52 +282,54 @@ end end @testitem "MessageMapping should call `rulefallback` is no rule is available" begin - import ReactiveMP: MessageMapping, getdata + import ReactiveMP: MessageMapping, getdata, AnnotationDict - struct SomeArbitraryNode end + struct SomeArbitraryNodeForRuleFallback end - @node SomeArbitraryNode Stochastic [out, in] + @node SomeArbitraryNodeForRuleFallback Stochastic [out, in] struct NonexistingDistribution end meta = "meta" - addons = () + annotations = nothing mapping_no_rule_fallback = MessageMapping( - SomeArbitraryNode, + SomeArbitraryNodeForRuleFallback, Val(:out), Marginalisation(), Val((:in,)), nothing, meta, - addons, - SomeArbitraryNode(), + annotations, + SomeArbitraryNodeForRuleFallback(), + nothing, nothing, ) - messages = (Message(NonexistingDistribution(), false, false, nothing),) + messages = (Message(NonexistingDistribution(), false, false),) marginals = nothing @test_throws ReactiveMP.RuleMethodError mapping_no_rule_fallback( messages, marginals ) - rulefallback = (args...) -> (args, nothing) + rulefallback = (args...) -> (args) mapping_with_fallback = MessageMapping( - SomeArbitraryNode, + SomeArbitraryNodeForRuleFallback, Val(:out), Marginalisation(), Val((:in,)), nothing, meta, - addons, - SomeArbitraryNode(), + annotations, + SomeArbitraryNodeForRuleFallback(), rulefallback, + nothing, ) @test getdata(mapping_with_fallback(messages, marginals)) == ( - SomeArbitraryNode, + SomeArbitraryNodeForRuleFallback, Val(:out), Marginalisation(), Val((:in,)), @@ -339,7 +337,569 @@ end nothing, marginals, meta, - addons, - SomeArbitraryNode(), + AnnotationDict(), + SomeArbitraryNodeForRuleFallback(), + ) +end + +@testitem "MessageMapping should call provided callbacks handler" begin + import ReactiveMP: MessageMapping, getdata, AnnotationDict + + struct SomeArbitraryNodeCallbacksTests end + + @node SomeArbitraryNodeCallbacksTests Stochastic [out, in] + + @rule SomeArbitraryNodeCallbacksTests(:out, Marginalisation) (m_in::Int,) = + m_in + 1 + + events = [] + + callbacks = ( + before_message_rule_call = (event) -> + push!(events, (event = :before_message_rule_call, data = event)), + after_message_rule_call = (event) -> + push!(events, (event = :after_message_rule_call, data = event)), + ) + + mapping = MessageMapping( + SomeArbitraryNodeCallbacksTests, + Val(:out), + Marginalisation(), + Val((:in,)), + nothing, + nothing, + (), + SomeArbitraryNodeCallbacksTests(), + nothing, + callbacks, + ) + + messages = (Message(1, false, false),) + marginals = nothing + + @test getdata(mapping(messages, marginals)) == 2 + + @test events[1].event == :before_message_rule_call + @test events[1].data.mapping.factornode === + SomeArbitraryNodeCallbacksTests() + @test events[1].data.messages === messages + @test events[1].data.marginals === marginals + + @test events[2].event == :after_message_rule_call + @test events[2].data.mapping.factornode === + SomeArbitraryNodeCallbacksTests() + @test events[2].data.messages === messages + @test events[2].data.marginals === marginals + @test events[2].data.result === 2 + @test events[2].data.annotations isa AnnotationDict +end + +@testmodule MessageProductContextUtils begin + import ReactiveMP: AbstractVariable, AbstractFormConstraint, constrain_form + import ReactiveMP.BayesBase: prod, GenericProd, isapprox + import ReactiveMP + + struct Normal + mean::Float64 + var::Float64 + end + + function prod(::GenericProd, left::Normal, right::Normal) + result_var = 1 / (1 / left.var + 1 / right.var) + result_mean = + result_var * (left.mean / left.var + right.mean / right.var) + return Normal(result_mean, result_var) + end + + function isapprox(left::Normal, right::Normal; kwargs...) + return isapprox(left.mean, right.mean; kwargs...) && + isapprox(left.var, right.var; kwargs...) + end + + struct AbstractVariableForMessageProductContextTests <: AbstractVariable end + + testvar = AbstractVariableForMessageProductContextTests() + + # A simple form constraint that adds +1 to the mean of a Normal distribution + struct AddOneToMeanConstraint <: AbstractFormConstraint end + + constrain_form(::AddOneToMeanConstraint, dist::Normal) = Normal( + dist.mean + 1, dist.var + ) + + # A callback handler that only records events from a specified set + struct SaveOrderOfComputationCallbacks + listen_to::Tuple + events + end + + function ReactiveMP.invoke_callback( + handler::SaveOrderOfComputationCallbacks, event::ReactiveMP.Event{E} + ) where {E} + E ∈ handler.listen_to && + push!(handler.events, (event = E, data = event)) + end + + export Normal, + testvar, AddOneToMeanConstraint, SaveOrderOfComputationCallbacks +end + +@testitem "MessageProductContext should compute product of two messages" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + Message, MessageProductContext, compute_product_of_two_messages, getdata + + context = MessageProductContext() + + msg1 = Message(Normal(0, 1), false, false) + msg2 = Message(Normal(0, 1), false, false) + + result = @inferred( + compute_product_of_two_messages(testvar, context, msg1, msg2) ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / 2) +end + +@testitem "compute_message_product propagates the `is_clamped` and `is_initial` correctly" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + Message, + MessageProductContext, + compute_product_of_two_messages, + is_clamped, + is_initial + + context = MessageProductContext() + + for left_is_clamped in (true, false), + right_is_clamped in (true, false), left_is_initial in (true, false), + right_is_initial in (true, false) + + msg1 = Message(Normal(0, 1), left_is_clamped, left_is_initial) + msg2 = Message(Normal(0, 1), right_is_clamped, right_is_initial) + + result = @inferred( + compute_product_of_two_messages(testvar, context, msg1, msg2) + ) + + @test result isa Message + + expected_result_is_clamped = left_is_clamped && right_is_clamped + expected_result_is_initial = + !expected_result_is_clamped && + (left_is_clamped || left_is_initial) && + (right_is_clamped || right_is_initial) + + @test is_clamped(result) === expected_result_is_clamped + @test is_initial(result) === expected_result_is_initial + end +end + +@testitem "compute_message_product should support different folding strategies" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + compute_product_of_messages, + compute_product_of_two_messages, + getdata, + AnnotationDict + + messages = [ + Message(Normal(0, 1), false, false) + Message(Normal(0, 2), false, false) + Message(Normal(0, 3), false, false) + ] + + @testset "From left to right" begin + import ReactiveMP: MessagesProductFromLeftToRight + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromLeftToRight(), + callbacks = handler, + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + # 3 messages = 2 products, each product fires a before and after callback + @test length(handler.events) == 4 + @test handler.events[1].event === :before_product_of_two_messages + @test handler.events[2].event === :after_product_of_two_messages + @test handler.events[3].event === :before_product_of_two_messages + @test handler.events[4].event === :after_product_of_two_messages + + # First product: Normal(0,1) × Normal(0,2) — left to right order + @test getdata(handler.events[1].data.left) == Normal(0, 1) + @test getdata(handler.events[1].data.right) == Normal(0, 2) + + # Second product: result of first × Normal(0,3) + @test getdata(handler.events[3].data.left) == + Normal(0, 1 / (1 / 1 + 1 / 2)) + @test getdata(handler.events[3].data.right) == Normal(0, 3) + end + + @testset "From right to left" begin + import ReactiveMP: MessagesProductFromRightToLeft + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromRightToLeft(), + callbacks = handler, + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + # 3 messages = 2 products, each product fires a before and after callback + @test length(handler.events) == 4 + @test handler.events[1].event === :before_product_of_two_messages + @test handler.events[2].event === :after_product_of_two_messages + @test handler.events[3].event === :before_product_of_two_messages + @test handler.events[4].event === :after_product_of_two_messages + + # First product: Normal(0,2) × Normal(0,3) — right to left order + @test getdata(handler.events[1].data.left) == Normal(0, 2) + @test getdata(handler.events[1].data.right) == Normal(0, 3) + + # Second product: Normal(0,1) × result of first + @test getdata(handler.events[3].data.left) == Normal(0, 1) + @test getdata(handler.events[3].data.right) == + Normal(0, 1 / (1 / 2 + 1 / 3)) + end + + @testset "Custom fold strategy via Function" begin + # Custom strategy: compute (1 × 3) × 2 + custom_fold = + (variable, context, messages) -> begin + first = compute_product_of_two_messages( + variable, context, messages[1], messages[3] + ) + return compute_product_of_two_messages( + variable, context, first, messages[2] + ) + end + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = custom_fold, callbacks = handler + ) + + result = @inferred( + compute_product_of_messages(testvar, context, messages) + ) + + @test result isa Message + @test getdata(result) === Normal(0, 1 / (1 + 1 / 2 + 1 / 3)) + + @test length(handler.events) == 4 + + # First product: Normal(0,1) × Normal(0,3) + @test getdata(handler.events[1].data.left) == Normal(0, 1) + @test getdata(handler.events[1].data.right) == Normal(0, 3) + + # Second product: result of first × Normal(0,2) + @test getdata(handler.events[3].data.left) == + Normal(0, 1 / (1 / 1 + 1 / 3)) + @test getdata(handler.events[3].data.right) == Normal(0, 2) + end + + @testset "Before and after callbacks receive correct arguments" begin + import ReactiveMP: MessagesProductFromLeftToRight + + listen_to = ( + :before_product_of_two_messages, :after_product_of_two_messages + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + fold_strategy = MessagesProductFromLeftToRight(), + callbacks = handler, + ) + + msg1 = Message(Normal(0, 1), false, false) + msg2 = Message(Normal(0, 2), false, false) + + result = @inferred( + compute_product_of_messages(testvar, context, [msg1, msg2]) + ) + + @test length(handler.events) == 2 + + # Before callback: variable, context, left, right + before = handler.events[1] + @test before.event === :before_product_of_two_messages + @test before.data.variable === testvar + @test before.data.context === context + @test getdata(before.data.left) == Normal(0, 1) + @test getdata(before.data.right) == Normal(0, 2) + + # After callback: variable, context, left, right, result, annotations + after = handler.events[2] + @test after.event === :after_product_of_two_messages + @test after.data.variable === testvar + @test after.data.context === context + @test getdata(after.data.left) == Normal(0, 1) + @test getdata(after.data.right) == Normal(0, 2) + @test after.data.result == result + @test after.data.annotations isa AnnotationDict # AnnotationDict is default + end +end + +@testitem "Form constraint callbacks with FormConstraintCheckEach" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + FormConstraintCheckEach, + compute_product_of_messages, + compute_product_of_two_messages, + getdata + + messages = [ + Message(Normal(0, 1), false, false) + Message(Normal(0, 2), false, false) + Message(Normal(0, 3), false, false) + ] + + @testset "CheckEach applies form constraint after each pairwise product" begin + listen_to = ( + :before_form_constraint_applied, :after_form_constraint_applied + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckEach(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # Hand-computed expected values (left-to-right fold with +1 to mean after each product): + # Step 1: prod(Normal(0,1), Normal(0,2)) + # var = 1/(1 + 1/2) = 2/3, mean = (2/3)*(0 + 0) = 0 => Normal(0, 2/3) + # constraint: Normal(0 + 1, 2/3) = Normal(1, 2/3) + # Step 2: prod(Normal(1, 2/3), Normal(0,3)) + # var = 1/(3/2 + 1/3) = 6/11, mean = (6/11)*(1*3/2 + 0*1/3) = 9/11 => Normal(9/11, 6/11) + # constraint: Normal(9/11 + 1, 6/11) = Normal(20/11, 6/11) + @test getdata(result) ≈ Normal(20 / 11, 6 / 11) + + # With CheckEach and 3 messages: 2 pairwise products => 2 before + 2 after = 4 events + @test length(handler.events) == 4 + @test handler.events[1].event === :before_form_constraint_applied + @test handler.events[2].event === :after_form_constraint_applied + @test handler.events[3].event === :before_form_constraint_applied + @test handler.events[4].event === :after_form_constraint_applied + + # All form constraint events should carry the CheckEach strategy + for e in handler.events + @test e.data.strategy === FormConstraintCheckEach() + end + + # First constraint: before gets Normal(0, 2/3), after gets Normal(1, 2/3) + @test handler.events[1].data.distribution ≈ Normal(0, 2 / 3) + @test handler.events[2].data.result ≈ Normal(1, 2 / 3) + + # Second constraint: before gets Normal(9/11, 6/11), after gets Normal(20/11, 6/11) + @test handler.events[3].data.distribution ≈ Normal(9 / 11, 6 / 11) + @test handler.events[4].data.result ≈ Normal(20 / 11, 6 / 11) + end +end + +@testitem "Form constraint callbacks with FormConstraintCheckLast" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, + Message, + FormConstraintCheckLast, + compute_product_of_messages, + getdata + + messages = [ + Message(Normal(0, 1), false, false) + Message(Normal(0, 2), false, false) + Message(Normal(0, 3), false, false) + ] + + @testset "CheckLast applies form constraint once at the end" begin + listen_to = ( + :before_form_constraint_applied, :after_form_constraint_applied + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckLast(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # Hand-computed expected values (left-to-right fold, constraint only at the end): + # Step 1: prod(Normal(0,1), Normal(0,2)) + # var = 2/3, mean = 0 => Normal(0, 2/3) — no constraint + # Step 2: prod(Normal(0, 2/3), Normal(0,3)) + # var = 1/(3/2 + 1/3) = 6/11, mean = 0 => Normal(0, 6/11) — no constraint + # Final constraint: Normal(0 + 1, 6/11) = Normal(1, 6/11) + @test getdata(result) ≈ Normal(1, 6 / 11) + + # With CheckLast, form constraint fires only once (1 before + 1 after) + @test length(handler.events) == 2 + @test handler.events[1].event === :before_form_constraint_applied + @test handler.events[2].event === :after_form_constraint_applied + + # The strategy should be CheckLast + @test handler.events[1].data.strategy === FormConstraintCheckLast() + @test handler.events[2].data.strategy === FormConstraintCheckLast() + + # Before constraint: Normal(0, 6/11), after constraint: Normal(1, 6/11) + @test handler.events[1].data.distribution ≈ Normal(0, 6 / 11) + @test handler.events[2].data.result ≈ Normal(1, 6 / 11) + + # The final result matches the after-constraint distribution + @test getdata(result) ≈ handler.events[2].data.result + end + + @testset "Before/after product of messages callbacks fire around the whole computation" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; + form_constraint = AddOneToMeanConstraint(), + form_constraint_check_strategy = FormConstraintCheckLast(), + callbacks = handler, + ) + + result = compute_product_of_messages(testvar, context, messages) + + # BeforeProductOfMessages should be the first event + @test length(handler.events) == 2 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[1].data.variable === testvar + @test handler.events[1].data.context === context + @test handler.events[1].data.messages === messages + + # AfterProductOfMessages should be the last event + @test handler.events[2].event === :after_product_of_messages + @test handler.events[2].data.variable === testvar + @test handler.events[2].data.context === context + @test handler.events[2].data.messages === messages + @test handler.events[2].data.result == result + end +end + +@testitem "Before/after product of messages callbacks" setup = [ + MessageProductContextUtils +] begin + import ReactiveMP: + MessageProductContext, Message, compute_product_of_messages, getdata + + @testset "Fires with default context (no form constraint)" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(0, 1), false, false) + Message(Normal(0, 2), false, false) + Message(Normal(0, 3), false, false) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # prod(Normal(0,1), Normal(0,2)) = Normal(0, 2/3) + # prod(Normal(0, 2/3), Normal(0,3)) = Normal(0, 6/11) + @test getdata(result) ≈ Normal(0, 6 / 11) + + @test length(handler.events) == 2 + + # Before: receives variable, context, and the original messages + @test handler.events[1].event === :before_product_of_messages + @test handler.events[1].data.variable === testvar + @test handler.events[1].data.context === context + @test handler.events[1].data.messages === messages + + # After: receives variable, context, original messages, and the final result + @test handler.events[2].event === :after_product_of_messages + @test handler.events[2].data.variable === testvar + @test handler.events[2].data.context === context + @test handler.events[2].data.messages === messages + @test getdata(handler.events[2].data.result) ≈ getdata(result) + end + + @testset "Fires with two messages" begin + listen_to = (:before_product_of_messages, :after_product_of_messages) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(1, 4), false, false) + Message(Normal(3, 4), false, false) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # prod(Normal(1,4), Normal(3,4)): + # var = 1/(1/4 + 1/4) = 2, mean = 2*(1/4 + 3/4) = 2 + @test getdata(result) ≈ Normal(2, 2) + + @test length(handler.events) == 2 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[2].event === :after_product_of_messages + @test getdata(handler.events[2].data.result) ≈ getdata(result) + end + + @testset "Before fires before any pairwise product, after fires after all" begin + # Listen to all product-related events to verify ordering + listen_to = ( + :before_product_of_messages, + :before_product_of_two_messages, + :after_product_of_two_messages, + :after_product_of_messages, + ) + handler = SaveOrderOfComputationCallbacks(listen_to, []) + context = MessageProductContext(; callbacks = handler) + + messages = [ + Message(Normal(0, 1), false, false) + Message(Normal(0, 2), false, false) + Message(Normal(0, 3), false, false) + ] + + result = compute_product_of_messages(testvar, context, messages) + + # 3 messages => 1 before_all + 2*(before_two + after_two) + 1 after_all = 6 events + @test length(handler.events) == 6 + @test handler.events[1].event === :before_product_of_messages + @test handler.events[2].event === :before_product_of_two_messages + @test handler.events[3].event === :after_product_of_two_messages + @test handler.events[4].event === :before_product_of_two_messages + @test handler.events[5].event === :after_product_of_two_messages + @test handler.events[6].event === :after_product_of_messages + + # The after_product_of_messages result should match the final result + @test getdata(handler.events[6].data.result) ≈ getdata(result) + end end diff --git a/test/nodes/clusters_tests.jl b/test/nodes/clusters_tests.jl index 461fcd081..649d8fea3 100644 --- a/test/nodes/clusters_tests.jl +++ b/test/nodes/clusters_tests.jl @@ -1,9 +1,10 @@ @testitem "FactorNodeLocalMarginal" begin + import Rocket: of, subscribe!, unsubscribe! import ReactiveMP: FactorNodeLocalMarginal, MarginalObservable, - getmarginal, - setmarginal!, + get_stream_of_marginals, + set_stream_of_marginals!, tag, name @@ -12,13 +13,40 @@ @test tag(localmarginal) === Val{:a}() @test occursin("a", repr(localmarginal)) # The stream is not set - @test_throws UndefRefError getmarginal(localmarginal) + @test_throws UndefRefError get_stream_of_marginals(localmarginal) m = MarginalObservable() - setmarginal!(localmarginal, m) + set_stream_of_marginals!(localmarginal, m) - @test getmarginal(localmarginal) === m + @test get_stream_of_marginals(localmarginal) === m + end + + @testset let localmarginal = FactorNodeLocalMarginal(:b) + @test name(localmarginal) === :b + @test tag(localmarginal) === Val{:b}() + @test occursin("b", repr(localmarginal)) + # The stream is not set + @test_throws UndefRefError get_stream_of_marginals(localmarginal) + + m = of(Marginal("message", false, false)) + + set_stream_of_marginals!(localmarginal, m) + + @test get_stream_of_marginals(localmarginal) !== m + + stream_of_marginals = get_stream_of_marginals(localmarginal) + + output_value = [] + + subscription = subscribe!( + stream_of_marginals, (d) -> push!(output_value, d) + ) + + @test length(output_value) === 1 + @test output_value[1] == Marginal("message", false, false) + + unsubscribe!(subscription) end end @@ -27,8 +55,7 @@ end NodeInterface, FactorNodeLocalClusters, getfactorization, - getmarginals, - getmarginal, + get_node_local_marginals, name a = NodeInterface(:a, randomvar()) @@ -40,8 +67,8 @@ end @testset let clusters = FactorNodeLocalClusters( interfaces, ((1, 2, 3),) ) - @test length(getmarginals(clusters)) === 1 - @test name(getmarginal(clusters, 1)) === :a_b_c + @test length(get_node_local_marginals(clusters)) === 1 + @test name(get_node_local_marginals(clusters)[1]) === :a_b_c @test getfactorization(clusters) === ((1, 2, 3),) @test getfactorization(clusters, 1) === (1, 2, 3) end @@ -49,9 +76,9 @@ end @testset let clusters = FactorNodeLocalClusters( interfaces, ((1, 2), (3,)) ) - @test length(getmarginals(clusters)) === 2 - @test name(getmarginal(clusters, 1)) === :a_b - @test name(getmarginal(clusters, 2)) === :c + @test length(get_node_local_marginals(clusters)) === 2 + @test name(get_node_local_marginals(clusters)[1]) === :a_b + @test name(get_node_local_marginals(clusters)[2]) === :c @test getfactorization(clusters) === ((1, 2), (3,)) @test getfactorization(clusters, 1) === (1, 2) @test getfactorization(clusters, 2) === (3,) @@ -60,9 +87,9 @@ end @testset let clusters = FactorNodeLocalClusters( interfaces, ((1,), (2, 3)) ) - @test length(getmarginals(clusters)) === 2 - @test name(getmarginal(clusters, 1)) === :a - @test name(getmarginal(clusters, 2)) === :b_c + @test length(get_node_local_marginals(clusters)) === 2 + @test name(get_node_local_marginals(clusters)[1]) === :a + @test name(get_node_local_marginals(clusters)[2]) === :b_c @test getfactorization(clusters) === ((1,), (2, 3)) @test getfactorization(clusters, 1) === (1,) @test getfactorization(clusters, 2) === (2, 3) @@ -71,10 +98,10 @@ end @testset let clusters = FactorNodeLocalClusters( interfaces, ((1,), (2,), (3,)) ) - @test length(getmarginals(clusters)) === 3 - @test name(getmarginal(clusters, 1)) === :a - @test name(getmarginal(clusters, 2)) === :b - @test name(getmarginal(clusters, 3)) === :c + @test length(get_node_local_marginals(clusters)) === 3 + @test name(get_node_local_marginals(clusters)[1]) === :a + @test name(get_node_local_marginals(clusters)[2]) === :b + @test name(get_node_local_marginals(clusters)[3]) === :c @test getfactorization(clusters) === ((1,), (2,), (3,)) @test getfactorization(clusters, 1) === (1,) @test getfactorization(clusters, 2) === (2,) @@ -194,7 +221,9 @@ end getlocalclusters, initialize_clusters!, getdata, - default_functional_dependencies + default_functional_dependencies, + get_node_local_marginals, + get_stream_of_marginals using BayesBase @@ -244,7 +273,7 @@ end nothing, nothing, nothing, nothing, nothing, nothing ) - @test length(getmarginals(getlocalclusters(node))) === 1 + @test length(get_node_local_marginals(getlocalclusters(node))) === 1 initialize_clusters!( getlocalclusters(node), dependencies, node, options @@ -252,7 +281,9 @@ end @test PointMass(vout + va + vb) == getdata( check_stream_updated_once( - getmarginal(getmarginal(getlocalclusters(node), 1)) + get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[1] + ), ), ) end @@ -272,7 +303,7 @@ end nothing, nothing, nothing, nothing, nothing, nothing ) - @test length(getmarginals(getlocalclusters(node))) === 2 + @test length(get_node_local_marginals(getlocalclusters(node))) === 2 initialize_clusters!( getlocalclusters(node), dependencies, node, options @@ -280,11 +311,14 @@ end @test PointMass(vout + va - vb) == getdata( check_stream_updated_once( - getmarginal(getmarginal(getlocalclusters(node), 1)) + get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[1] + ), ), ) - @test getmarginal(getmarginal(getlocalclusters(node), 2)) === - getmarginal(b, IncludeAll()) + @test get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[2] + ) === get_stream_of_marginals(b) end end @@ -302,7 +336,7 @@ end nothing, nothing, nothing, nothing, nothing, nothing ) - @test length(getmarginals(getlocalclusters(node))) === 2 + @test length(get_node_local_marginals(getlocalclusters(node))) === 2 initialize_clusters!( getlocalclusters(node), dependencies, node, options @@ -310,11 +344,14 @@ end @test PointMass(vout + vb - va) == getdata( check_stream_updated_once( - getmarginal(getmarginal(getlocalclusters(node), 1)) + get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[1] + ), ), ) - @test getmarginal(getmarginal(getlocalclusters(node), 2)) === - getmarginal(a, IncludeAll()) + @test get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[2] + ) === get_stream_of_marginals(a) end end @@ -332,7 +369,7 @@ end nothing, nothing, nothing, nothing, nothing, nothing ) - @test length(getmarginals(getlocalclusters(node))) === 2 + @test length(get_node_local_marginals(getlocalclusters(node))) === 2 initialize_clusters!( getlocalclusters(node), dependencies, node, options @@ -340,11 +377,14 @@ end @test PointMass(va + vb - vout) == getdata( check_stream_updated_once( - getmarginal(getmarginal(getlocalclusters(node), 2)) + get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[2] + ), ), ) - @test getmarginal(getmarginal(getlocalclusters(node), 1)) === - getmarginal(out, IncludeAll()) + @test get_stream_of_marginals( + get_node_local_marginals(getlocalclusters(node))[1] + ) === get_stream_of_marginals(out) end end end diff --git a/test/nodes/dependencies_tests.jl b/test/nodes/dependencies_tests.jl index 3f8c4232f..291b4e0d1 100644 --- a/test/nodes/dependencies_tests.jl +++ b/test/nodes/dependencies_tests.jl @@ -73,13 +73,14 @@ end import ReactiveMP: NodeInterface, FactorNodeLocalMarginal, - getmarginal, collect_latest_marginals, getdata, getrecent, default_functional_dependencies, getlocalclusters, - getmarginals + get_stream_of_marginals, + set_stream_of_marginals!, + get_node_local_marginals struct ArbitraryNodeForCollectLatestMarginals end @@ -98,11 +99,11 @@ end ArbitraryNodeForCollectLatestMarginals ) - a, b, c = getmarginals(getlocalclusters(node)) + a, b, c = get_node_local_marginals(getlocalclusters(node)) - setmarginal!(a, getmarginal(a_v, IncludeAll())) - setmarginal!(b, getmarginal(b_v, IncludeAll())) - setmarginal!(c, getmarginal(c_v, IncludeAll())) + set_stream_of_marginals!(a, get_stream_of_marginals(a_v)) + set_stream_of_marginals!(b, get_stream_of_marginals(b_v)) + set_stream_of_marginals!(c, get_stream_of_marginals(c_v)) @testset let (tag, stream) = collect_latest_marginals( dependencies, node, (a, b, c) @@ -145,8 +146,10 @@ end RandomVariableActivationOptions, functional_dependencies, getinterfaces, - messagein, - getdata + get_stream_of_inbound_messages, + getdata, + get_node_local_marginals, + get_stream_of_marginals struct ArbitraryFactorNode end @@ -196,15 +199,21 @@ end @test interfaces[3] ∉ msg_dependencies_for_c @test isempty(marginal_dependencies_for_c) - @test check_stream_not_updated(messagein(interfaces[1])) - @test check_stream_not_updated(messagein(interfaces[2])) - @test check_stream_not_updated(messagein(interfaces[3])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[1]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[2]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[3]) + ) end @testset "RequireMessageFunctionalDependencies(a = nothing)" begin import ReactiveMP: RequireMessageFunctionalDependencies - dependencies = RequireMessageFunctionalDependencies(a = nothing) + dependencies = RequireMessageFunctionalDependencies(; a = nothing) msg_dependencies_for_a, marginal_dependencies_for_a = functional_dependencies( dependencies, node, interfaces[1], 1 @@ -231,16 +240,22 @@ end @test interfaces[3] ∉ msg_dependencies_for_c @test isempty(marginal_dependencies_for_c) - @test check_stream_not_updated(messagein(interfaces[1])) - @test check_stream_not_updated(messagein(interfaces[2])) - @test check_stream_not_updated(messagein(interfaces[3])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[1]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[2]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[3]) + ) end @testset "RequireMessageFunctionalDependencies(b = ...)" begin import ReactiveMP: RequireMessageFunctionalDependencies for initialmessage in (1, 2.0, "hello") - dependencies = RequireMessageFunctionalDependencies( + dependencies = RequireMessageFunctionalDependencies(; b = initialmessage ) @@ -269,18 +284,24 @@ end @test interfaces[3] ∉ msg_dependencies_for_c @test isempty(marginal_dependencies_for_c) - @test check_stream_not_updated(messagein(interfaces[1])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[1]) + ) @test getdata( - check_stream_updated_once(messagein(interfaces[2])) + check_stream_updated_once( + get_stream_of_inbound_messages(interfaces[2]) + ), ) === initialmessage - @test check_stream_not_updated(messagein(interfaces[3])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[3]) + ) end end @testset "RequireMarginalFunctionalDependencies(a = nothing)" begin import ReactiveMP: RequireMarginalFunctionalDependencies - dependencies = RequireMarginalFunctionalDependencies(a = nothing) + dependencies = RequireMarginalFunctionalDependencies(; a = nothing) msg_dependencies_for_a, marginal_dependencies_for_a = functional_dependencies( dependencies, node, interfaces[1], 1 @@ -298,7 +319,7 @@ end @test !isempty(marginal_dependencies_for_a) @test length(marginal_dependencies_for_a) === 1 @test check_stream_not_updated( - getmarginal(first(marginal_dependencies_for_a)) + get_stream_of_marginals(first(marginal_dependencies_for_a)) ) @test interfaces[1] ∈ msg_dependencies_for_b @@ -318,7 +339,7 @@ end for initialmarginal in (1, 2.0, "hello") import ReactiveMP: RequireMarginalFunctionalDependencies - dependencies = RequireMarginalFunctionalDependencies( + dependencies = RequireMarginalFunctionalDependencies(; a = initialmarginal ) @@ -339,7 +360,9 @@ end @test length(marginal_dependencies_for_a) === 1 @test getdata( check_stream_updated_once( - getmarginal(first(marginal_dependencies_for_a)) + get_stream_of_marginals( + first(marginal_dependencies_for_a) + ), ), ) === initialmarginal @@ -406,7 +429,7 @@ end @testset "RequireMessageFunctionalDependencies(a = nothing)" begin import ReactiveMP: RequireMessageFunctionalDependencies - dependencies = RequireMessageFunctionalDependencies(a = nothing) + dependencies = RequireMessageFunctionalDependencies(; a = nothing) msg_dependencies_for_a, marginal_dependencies_for_a = functional_dependencies( dependencies, node, interfaces[1], 1 @@ -434,16 +457,22 @@ end @test :c ∉ name.(marginal_dependencies_for_c) @test isempty(msg_dependencies_for_c) - @test check_stream_not_updated(messagein(interfaces[1])) - @test check_stream_not_updated(messagein(interfaces[2])) - @test check_stream_not_updated(messagein(interfaces[3])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[1]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[2]) + ) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[3]) + ) end @testset "RequireMessageFunctionalDependencies(b = ...)" begin import ReactiveMP: RequireMessageFunctionalDependencies for initialmessage in (1, 2.0, "hello") - dependencies = RequireMessageFunctionalDependencies( + dependencies = RequireMessageFunctionalDependencies(; b = initialmessage ) @@ -473,18 +502,24 @@ end @test :c ∉ name.(marginal_dependencies_for_c) @test isempty(msg_dependencies_for_c) - @test check_stream_not_updated(messagein(interfaces[1])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[1]) + ) @test getdata( - check_stream_updated_once(messagein(interfaces[2])) + check_stream_updated_once( + get_stream_of_inbound_messages(interfaces[2]) + ), ) === initialmessage - @test check_stream_not_updated(messagein(interfaces[3])) + @test check_stream_not_updated( + get_stream_of_inbound_messages(interfaces[3]) + ) end end @testset "RequireMarginalFunctionalDependencies(a = nothing)" begin import ReactiveMP: RequireMarginalFunctionalDependencies - dependencies = RequireMarginalFunctionalDependencies(a = nothing) + dependencies = RequireMarginalFunctionalDependencies(; a = nothing) msg_dependencies_for_a, marginal_dependencies_for_a = functional_dependencies( dependencies, node, interfaces[1], 1 @@ -501,7 +536,7 @@ end @test :c ∈ name.(marginal_dependencies_for_a) @test isempty(msg_dependencies_for_a) @test check_stream_not_updated( - getmarginal(first(marginal_dependencies_for_a)) + get_stream_of_marginals(first(marginal_dependencies_for_a)) ) @test :a ∈ name.(marginal_dependencies_for_b) @@ -521,7 +556,7 @@ end for initialmarginal in (1, 2.0, "hello") import ReactiveMP: RequireMarginalFunctionalDependencies - dependencies = RequireMarginalFunctionalDependencies( + dependencies = RequireMarginalFunctionalDependencies(; a = initialmarginal ) @@ -541,7 +576,9 @@ end @test isempty(msg_dependencies_for_a) @test getdata( check_stream_updated_once( - getmarginal(first(marginal_dependencies_for_a)) + get_stream_of_marginals( + first(marginal_dependencies_for_a) + ), ), ) === initialmarginal @@ -640,7 +677,7 @@ end @testset "use_a metadata results in CustomDependencyA" begin options_a = FactorNodeActivationOptions( - :use_a, nothing, nothing, nothing, AsapScheduler(), nothing + :use_a, nothing, nothing, nothing, nothing, nothing ) deps = collect_functional_dependencies(CustomMetaNode, options_a) @test deps isa CustomDependencyA @@ -648,7 +685,7 @@ end @testset "use_b metadata results in CustomDependencyB" begin options_b = FactorNodeActivationOptions( - :use_b, nothing, nothing, nothing, AsapScheduler(), nothing + :use_b, nothing, nothing, nothing, nothing, nothing ) deps = collect_functional_dependencies(CustomMetaNode, options_b) @test deps isa CustomDependencyB @@ -656,7 +693,7 @@ end @testset "no metadata falls back to default dependencies" begin options_default = FactorNodeActivationOptions( - nothing, nothing, nothing, nothing, AsapScheduler(), nothing + nothing, nothing, nothing, nothing, nothing, nothing ) deps = collect_functional_dependencies(CustomMetaNode, options_default) @test deps isa DefaultFunctionalDependencies @@ -670,7 +707,7 @@ end ((1,),), ) options_a = FactorNodeActivationOptions( - :use_a, nothing, nothing, nothing, AsapScheduler(), nothing + :use_a, nothing, nothing, nothing, nothing, nothing ) deps_a = collect_functional_dependencies(CustomMetaNode, options_a) activate!(node_a, options_a) @@ -689,7 +726,7 @@ end ((1,),), ) options_b = FactorNodeActivationOptions( - :use_b, nothing, nothing, nothing, AsapScheduler(), nothing + :use_b, nothing, nothing, nothing, nothing, nothing ) deps_b = collect_functional_dependencies(CustomMetaNode, options_b) activate!(node_b, options_b) diff --git a/test/nodes/interfaces_tests.jl b/test/nodes/interfaces_tests.jl index 07ab744ba..ca0a1da42 100644 --- a/test/nodes/interfaces_tests.jl +++ b/test/nodes/interfaces_tests.jl @@ -4,8 +4,8 @@ import ReactiveMP: AbstractVariable, NodeInterface, - messageout, - messagein, + get_stream_of_outbound_messages, + get_stream_of_inbound_messages, tag, getvariable, MessageObservable, @@ -13,52 +13,54 @@ name struct AbstractVariableImplemention <: AbstractVariable - messageout::MessageObservable + stream_of_outbound_messages::MessageObservable end - ReactiveMP.create_messagein!(variable::AbstractVariableImplemention) = ( - variable.messageout, 1 + ReactiveMP.create_new_stream_of_inbound_messages!(variable::AbstractVariableImplemention) = ( + variable.stream_of_outbound_messages, 1 ) - ReactiveMP.messageout(variable::AbstractVariableImplemention, ::Int) = - variable.messageout + ReactiveMP.get_stream_of_outbound_messages( + variable::AbstractVariableImplemention, ::Int + ) = variable.stream_of_outbound_messages - varmessageout = MessageObservable() + stream_of_outbound_messages = MessageObservable() stream = Subject(AbstractMessage) - connect!(varmessageout, stream) - variable = AbstractVariableImplemention(varmessageout) + connect!(stream_of_outbound_messages, stream) + variable = AbstractVariableImplemention(stream_of_outbound_messages) interface = NodeInterface(:name, variable) @test name(interface) === :name @test occursin("name", repr(interface)) @test tag(interface) === Val{:name}() @test getvariable(interface) === variable - @test messagein(interface) === varmessageout + @test get_stream_of_inbound_messages(interface) === + stream_of_outbound_messages actor = keep(AbstractMessage) - subscription = subscribe!(messageout(interface), actor) + subscription = subscribe!(get_stream_of_outbound_messages(interface), actor) - next!(stream, Message(1, false, false, nothing)) + next!(stream, Message(1, false, false)) - @test getvalues(actor) == [Message(1, false, false, nothing)] + @test getvalues(actor) == [Message(1, false, false)] - next!(stream, Message(2, false, false, nothing)) - next!(stream, Message(3, false, false, nothing)) + next!(stream, Message(2, false, false)) + next!(stream, Message(3, false, false)) @test getvalues(actor) == [ - Message(1, false, false, nothing), - Message(2, false, false, nothing), - Message(3, false, false, nothing), + Message(1, false, false), + Message(2, false, false), + Message(3, false, false), ] unsubscribe!(subscription) - next!(stream, Message(4, false, false, nothing)) - next!(stream, Message(5, false, false, nothing)) + next!(stream, Message(4, false, false)) + next!(stream, Message(5, false, false)) @test getvalues(actor) == [ - Message(1, false, false, nothing), - Message(2, false, false, nothing), - Message(3, false, false, nothing), + Message(1, false, false), + Message(2, false, false), + Message(3, false, false), ] end diff --git a/test/nodes/predefined/autoregressive_tests.jl b/test/nodes/predefined/autoregressive_tests.jl index 8ce78cf01..9349c04a2 100644 --- a/test/nodes/predefined/autoregressive_tests.jl +++ b/test/nodes/predefined/autoregressive_tests.jl @@ -11,10 +11,10 @@ q_γ = GammaShapeRate(2.0, 3.0) marginals = ( - Marginal(q_y, false, false, nothing), - Marginal(q_x, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y, false, false), + Marginal(q_x, false, false), + Marginal(q_θ, false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), @@ -30,10 +30,10 @@ q_γ = GammaShapeRate(2.0, 3.0) marginals = ( - Marginal(q_y, false, false, nothing), - Marginal(q_x, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y, false, false), + Marginal(q_x, false, false), + Marginal(q_θ, false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), @@ -48,9 +48,9 @@ q_γ = GammaShapeRate(2.0, 3.0) marginals = ( - Marginal(q_y_x, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y_x, false, false), + Marginal(q_θ, false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/bifm_helper_tests.jl b/test/nodes/predefined/bifm_helper_tests.jl index dc30eb1a7..62b987b5a 100644 --- a/test/nodes/predefined/bifm_helper_tests.jl +++ b/test/nodes/predefined/bifm_helper_tests.jl @@ -9,16 +9,10 @@ Val{(:out, :in)}(), ( Marginal( - MvNormalMeanCovariance([1, 1], [2 0; 0 3]), - false, - false, - nothing, + MvNormalMeanCovariance([1, 1], [2 0; 0 3]), false, false ), Marginal( - MvNormalMeanCovariance([1, 1], [2 0; 0 3]), - false, - false, - nothing, + MvNormalMeanCovariance([1, 1], [2 0; 0 3]), false, false ), ), nothing, @@ -30,16 +24,10 @@ Val{(:out, :in)}(), ( Marginal( - MvNormalMeanCovariance([1, 2], [2 0; 0 1]), - false, - false, - nothing, + MvNormalMeanCovariance([1, 2], [2 0; 0 1]), false, false ), Marginal( - MvNormalMeanPrecision([1, 2], [0.5 0; 0 1]), - false, - false, - nothing, + MvNormalMeanPrecision([1, 2], [0.5 0; 0 1]), false, false ), ), nothing, diff --git a/test/nodes/predefined/binomial_polya_tests.jl b/test/nodes/predefined/binomial_polya_tests.jl index 064c69b0b..93a54fd9c 100644 --- a/test/nodes/predefined/binomial_polya_tests.jl +++ b/test/nodes/predefined/binomial_polya_tests.jl @@ -8,14 +8,13 @@ BinomialPolya, Val{(:y, :x, :n, :β)}(), ( - Marginal(PointMass(1), false, false, nothing), - Marginal(PointMass([0.1, 0.2]), false, false, nothing), - Marginal(PointMass(5), false, false, nothing), + Marginal(PointMass(1), false, false), + Marginal(PointMass([0.1, 0.2]), false, false), + Marginal(PointMass(5), false, false), Marginal( MvNormalWeightedMeanPrecision(zeros(2), diageye(2)), false, false, - nothing, ), ), nothing, @@ -29,14 +28,13 @@ BinomialPolya, Val{(:y, :x, :n, :β)}(), ( - Marginal(PointMass(1), false, false, nothing), - Marginal(PointMass([0.1, 0.2]), false, false, nothing), - Marginal(PointMass(5), false, false, nothing), + Marginal(PointMass(1), false, false), + Marginal(PointMass([0.1, 0.2]), false, false), + Marginal(PointMass(5), false, false), Marginal( MvNormalWeightedMeanPrecision(zeros(2), diageye(2)), false, false, - nothing, ), ), meta, diff --git a/test/nodes/predefined/continuous_transition_tests.jl b/test/nodes/predefined/continuous_transition_tests.jl index eeee278d9..7cb1f9484 100644 --- a/test/nodes/predefined/continuous_transition_tests.jl +++ b/test/nodes/predefined/continuous_transition_tests.jl @@ -19,15 +19,15 @@ q_W = Wishart(dy + 1, diageye(dy)) marginals_st = ( - Marginal(q_y_x, false, false, nothing), - Marginal(q_a, false, false, nothing), - Marginal(q_W, false, false, nothing), + Marginal(q_y_x, false, false), + Marginal(q_a, false, false), + Marginal(q_W, false, false), ) marginals_mf = ( - Marginal(q_y, false, false, nothing), - Marginal(q_x, false, false, nothing), - Marginal(q_a, false, false, nothing), - Marginal(q_W, false, false, nothing), + Marginal(q_y, false, false), + Marginal(q_x, false, false), + Marginal(q_a, false, false), + Marginal(q_W, false, false), ) # 12,992 is a result of manual calculation diff --git a/test/nodes/predefined/delta/delta_tests.jl b/test/nodes/predefined/delta/delta_tests.jl index 749ab6749..60b36c692 100644 --- a/test/nodes/predefined/delta/delta_tests.jl +++ b/test/nodes/predefined/delta/delta_tests.jl @@ -5,8 +5,8 @@ nodefunction, DeltaMeta, Linearization, - messageout, activate!, + new_observation!, RandomVariableActivationOptions, DataVariableActivationOptions @@ -21,12 +21,12 @@ node = factornode( foo, [(:out, out), (:in, x), (:in, y), (:in, z)], ((1, 2, 3, 4),) ) - meta = DeltaMeta(method = Linearization()) + meta = DeltaMeta(; method = Linearization()) activate!(x, RandomVariableActivationOptions()) activate!(y, DataVariableActivationOptions()) - update!(y, 2.0) + new_observation!(y, 2.0) for xval in rand(10) @test nodefunction(node, meta, Val(:out))(xval) === foo(xval, 2.0, 3.0) @@ -41,8 +41,8 @@ end nodefunction, DeltaMeta, Linearization, - messageout, activate!, + new_observation!, RandomVariableActivationOptions, DataVariableActivationOptions @@ -86,7 +86,7 @@ end i -> i isa Tuple{Symbol, RandomVariable}, in_interfaces ) node = factornode(foo, interfaces, ((1, 2, 3, 4),)) - meta = DeltaMeta(method = Linearization()) + meta = DeltaMeta(; method = Linearization()) foreach(interfaces) do (_, interface) activate_interface(interface) @@ -95,7 +95,7 @@ end # data variable inputs require an actual update foreach(enumerate(in_interfaces)) do (i, interface) if interface isa Tuple{Symbol, DataVariable} - update!(interface[2], vals[i]) + new_observation!(interface[2], vals[i]) end end @@ -122,7 +122,7 @@ end true ) - @test DeltaMeta(method = SupportedApproximationMetßhod()) isa DeltaMeta + @test DeltaMeta(; method = SupportedApproximationMetßhod()) isa DeltaMeta end @testitem "DeltaNode - CVI layout functionality" begin @@ -134,7 +134,8 @@ end DeltaFnNode, DeltaMeta, CVIProjection, - messageout, + get_stream_of_outbound_messages, + new_observation!, activate!, RandomVariableActivationOptions, DataVariableActivationOptions @@ -151,7 +152,7 @@ end node = factornode(f, [(:out, out), (:in, x), (:in, y)], ((1, 2, 3),)) # Test meta creation and compatibility - meta = DeltaMeta(method = CVIProjection()) + meta = DeltaMeta(; method = CVIProjection()) @test meta.method isa CVIProjection @test isnothing(meta.inverse) @@ -160,7 +161,8 @@ end activate!(y, DataVariableActivationOptions()) # Test data variable update propagation - update!(y, 2.0) - @test BayesBase.getpointmass(getdata(Rocket.getrecent(messageout(y, 1)))) ≈ - 2.0 + new_observation!(y, 2.0) + @test BayesBase.getpointmass( + getdata(Rocket.getrecent(get_stream_of_outbound_messages(y, 1))) + ) ≈ 2.0 end diff --git a/test/nodes/predefined/dirichlet_collection_tests.jl b/test/nodes/predefined/dirichlet_collection_tests.jl index e67662567..8649af5e0 100644 --- a/test/nodes/predefined/dirichlet_collection_tests.jl +++ b/test/nodes/predefined/dirichlet_collection_tests.jl @@ -15,8 +15,8 @@ q_a = PointMass(a) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_a, false, false), ) avg_energy = score( AverageEnergy(), @@ -32,8 +32,8 @@ avg_energy_matrix = 0.0 for (dir, a) in zip(q_out, q_a) marginals = ( - Marginal(dir, false, false, nothing), - Marginal(a, false, false, nothing), + Marginal(dir, false, false), + Marginal(a, false, false), ) avg_energy_matrix += score( AverageEnergy(), diff --git a/test/nodes/predefined/discrete_transition_tests.jl b/test/nodes/predefined/discrete_transition_tests.jl index 7b75ca95a..056e937de 100644 --- a/test/nodes/predefined/discrete_transition_tests.jl +++ b/test/nodes/predefined/discrete_transition_tests.jl @@ -21,8 +21,7 @@ q_a = PointMass(a_matrix) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), Marginal(q_a, false, false) ) # Expected value calculated by hand @@ -43,8 +42,7 @@ q_a = PointMass(a_matrix) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), Marginal(q_a, false, false) ) expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) @@ -66,8 +64,7 @@ q_a = PointMass(a_matrix) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), Marginal(q_a, false, false) ) expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) @@ -85,8 +82,7 @@ q_a = PointMass(diageye(3)) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), Marginal(q_a, false, false) ) expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) @@ -105,9 +101,9 @@ q_a = PointMass([0.7 0.3; 0.2 0.8]) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_in, false, false), + Marginal(q_a, false, false), ) contingency = probvec(q_out) * probvec(q_in)' @@ -126,9 +122,9 @@ q_a = PointMass([1.0 0.0; 1.0 0.0]) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_in, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_in, false, false), + Marginal(q_a, false, false), ) contingency = probvec(q_out) * probvec(q_in)' @@ -154,9 +150,9 @@ ) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_T1_T2, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), + Marginal(q_T1_T2, false, false), + Marginal(q_a, false, false), ) contingency = @@ -182,9 +178,9 @@ ) marginals = ( - Marginal(q_out_in, false, false, nothing), - Marginal(q_T1_T2, false, false, nothing), - Marginal(q_a, false, false, nothing), + Marginal(q_out_in, false, false), + Marginal(q_T1_T2, false, false), + Marginal(q_a, false, false), ) contingency = diff --git a/test/nodes/predefined/gamma_inverse_tests.jl b/test/nodes/predefined/gamma_inverse_tests.jl index 16529835c..8fd4020b5 100644 --- a/test/nodes/predefined/gamma_inverse_tests.jl +++ b/test/nodes/predefined/gamma_inverse_tests.jl @@ -10,9 +10,9 @@ q_θ = PointMass(1.0) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_α, false, false, nothing), - Marginal(q_θ, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_α, false, false), + Marginal(q_θ, false, false), ) @test score( @@ -29,9 +29,9 @@ q_θ = PointMass(42.0) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_α, false, false, nothing), - Marginal(q_θ, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_α, false, false), + Marginal(q_θ, false, false), ) @test score( diff --git a/test/nodes/predefined/gamma_mixture_tests.jl b/test/nodes/predefined/gamma_mixture_tests.jl index a70384e84..e753686b6 100644 --- a/test/nodes/predefined/gamma_mixture_tests.jl +++ b/test/nodes/predefined/gamma_mixture_tests.jl @@ -191,10 +191,10 @@ end q_b = (GammaShapeRate(1.5, 2.5), GammaShapeRate(3.5, 4.5)) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_switch, false, false, nothing), - ManyOf(map(q -> Marginal(q, false, false, nothing), q_a)), - ManyOf(map(q -> Marginal(q, false, false, nothing), q_b)), + Marginal(q_out, false, false), + Marginal(q_switch, false, false), + ManyOf(map(q -> Marginal(q, false, false), q_a)), + ManyOf(map(q -> Marginal(q, false, false), q_b)), ) # @average_energy GammaMixture (q_out::Any, q_switch::Any, q_a::ManyOf{N, Any}, q_b::ManyOf{N, GammaShapeRate}) @@ -204,20 +204,14 @@ end AverageEnergy(), GammaShapeRate, Val{(:out, :α, :β)}(), - map( - (q) -> Marginal(q, false, false, nothing), - (q_out, q_a[1], q_b[1]), - ), + map((q) -> Marginal(q, false, false), (q_out, q_a[1], q_b[1])), nothing, ) + z[2] * score( AverageEnergy(), GammaShapeRate, Val{(:out, :α, :β)}(), - map( - (q) -> Marginal(q, false, false, nothing), - (q_out, q_a[2], q_b[2]), - ), + map((q) -> Marginal(q, false, false), (q_out, q_a[2], q_b[2])), nothing, ) diff --git a/test/nodes/predefined/half_normal_tests.jl b/test/nodes/predefined/half_normal_tests.jl index 394f78b31..f5473cb2c 100644 --- a/test/nodes/predefined/half_normal_tests.jl +++ b/test/nodes/predefined/half_normal_tests.jl @@ -8,8 +8,7 @@ q_v = PointMass(2.0) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_v, false, false, nothing), + Marginal(q_out, false, false), Marginal(q_v, false, false) ) @test score( @@ -25,8 +24,7 @@ q_v = PointMass(2.0) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_v, false, false, nothing), + Marginal(q_out, false, false), Marginal(q_v, false, false) ) @test score( @@ -43,8 +41,7 @@ q_v = PointMass(2.0) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_v, false, false, nothing), + Marginal(q_out, false, false), Marginal(q_v, false, false) ) @test score( diff --git a/test/nodes/predefined/mixture_tests.jl b/test/nodes/predefined/mixture_tests.jl index d1deaa736..97967b2f2 100644 --- a/test/nodes/predefined/mixture_tests.jl +++ b/test/nodes/predefined/mixture_tests.jl @@ -21,7 +21,6 @@ interfaces, sdtype, factornode, - FactorNodeActivationOptions, activate! # Common interfaces and factorizations used by both test groups diff --git a/test/nodes/predefined/multinomial_polya_tests.jl b/test/nodes/predefined/multinomial_polya_tests.jl index 060f83223..362a3c6bd 100644 --- a/test/nodes/predefined/multinomial_polya_tests.jl +++ b/test/nodes/predefined/multinomial_polya_tests.jl @@ -17,9 +17,9 @@ MultinomialPolya, Val{(:x, :N, :ψ)}(), ( - Marginal(q_x, false, false, nothing), - Marginal(q_N, false, false, nothing), - Marginal(q_ψ, false, false, meta), + Marginal(q_x, false, false), + Marginal(q_N, false, false), + Marginal(q_ψ, false, false), ), meta, ) ≈ 104.19 atol = 0.1 @@ -35,9 +35,9 @@ MultinomialPolya, Val{(:x, :N, :ψ)}(), ( - Marginal(q_x, false, false, nothing), - Marginal(q_N, false, false, nothing), - Marginal(q_ψ, false, false, meta), + Marginal(q_x, false, false), + Marginal(q_N, false, false), + Marginal(q_ψ, false, false), ), meta, ) ≈ -101.72 atol = 0.1 @@ -53,9 +53,9 @@ MultinomialPolya, Val{(:x, :N, :ψ)}(), ( - Marginal(q_x, false, false, nothing), - Marginal(q_N, false, false, nothing), - Marginal(q_ψ, false, false, meta), + Marginal(q_x, false, false), + Marginal(q_N, false, false), + Marginal(q_ψ, false, false), ), meta, ) diff --git a/test/nodes/predefined/mv_normal_mean_covariance_tests.jl b/test/nodes/predefined/mv_normal_mean_covariance_tests.jl index 88f0b61a2..02b1aabdc 100644 --- a/test/nodes/predefined/mv_normal_mean_covariance_tests.jl +++ b/test/nodes/predefined/mv_normal_mean_covariance_tests.jl @@ -14,9 +14,9 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), @@ -62,9 +62,9 @@ ) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), @@ -110,9 +110,9 @@ ) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), @@ -152,8 +152,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), @@ -193,8 +193,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/mv_normal_mean_precision_tests.jl b/test/nodes/predefined/mv_normal_mean_precision_tests.jl index 678a615d0..bcf0f77cb 100644 --- a/test/nodes/predefined/mv_normal_mean_precision_tests.jl +++ b/test/nodes/predefined/mv_normal_mean_precision_tests.jl @@ -16,9 +16,9 @@ G in (Wishart,) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), @@ -63,9 +63,9 @@ G in (Wishart,) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), @@ -113,9 +113,9 @@ G in (Wishart,) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), @@ -166,9 +166,9 @@ G in (Wishart,) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), @@ -217,9 +217,9 @@ ) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(q_Λ, false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(q_Λ, false, false), ) @test score( AverageEnergy(), @@ -259,8 +259,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_Λ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_Λ, false, false), ) @test score( AverageEnergy(), @@ -300,8 +300,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_Λ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_Λ, false, false), ) @test score( AverageEnergy(), @@ -344,8 +344,8 @@ G in (Wishart,) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), @@ -388,8 +388,8 @@ G in (Wishart,) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_Λ), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_Λ), false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/mv_normal_mean_scale_matrix_precision_tests.jl b/test/nodes/predefined/mv_normal_mean_scale_matrix_precision_tests.jl index 68b5b2365..00381cd16 100644 --- a/test/nodes/predefined/mv_normal_mean_scale_matrix_precision_tests.jl +++ b/test/nodes/predefined/mv_normal_mean_scale_matrix_precision_tests.jl @@ -21,10 +21,10 @@ M in (Wishart,) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(g, q_γ), false, false, nothing), - Marginal(convert(M, q_G), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(g, q_γ), false, false), + Marginal(convert(M, q_G), false, false), ) @test score( AverageEnergy(), @@ -71,10 +71,10 @@ M in (Wishart,) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(convert(g, q_γ), false, false, nothing), - Marginal(convert(M, q_G), false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(convert(g, q_γ), false, false), + Marginal(convert(M, q_G), false, false), ) @test score( AverageEnergy(), @@ -116,9 +116,9 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_γ, false, false, nothing), - Marginal(q_G, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_γ, false, false), + Marginal(q_G, false, false), ) @test score( AverageEnergy(), @@ -161,9 +161,9 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_γ, false, false, nothing), - Marginal(q_G, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_γ, false, false), + Marginal(q_G, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/mv_normal_mean_scale_precision_tests.jl b/test/nodes/predefined/mv_normal_mean_scale_precision_tests.jl index ff45821fc..74c4512d2 100644 --- a/test/nodes/predefined/mv_normal_mean_scale_precision_tests.jl +++ b/test/nodes/predefined/mv_normal_mean_scale_precision_tests.jl @@ -16,9 +16,9 @@ g in (Gamma,) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(g, q_γ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(g, q_γ), false, false), ) @test score( AverageEnergy(), @@ -60,9 +60,9 @@ g in (Gamma,) marginals = ( - Marginal(convert(N1, q_out), false, false, nothing), - Marginal(convert(N2, q_μ), false, false, nothing), - Marginal(convert(g, q_γ), false, false, nothing), + Marginal(convert(N1, q_out), false, false), + Marginal(convert(N2, q_μ), false, false), + Marginal(convert(g, q_γ), false, false), ) @test score( AverageEnergy(), @@ -97,8 +97,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), @@ -133,8 +133,8 @@ MvNormalWeightedMeanPrecision, ) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/mv_normal_weightedmean_precision_tests.jl b/test/nodes/predefined/mv_normal_weightedmean_precision_tests.jl index aa4558154..fb51d77e0 100644 --- a/test/nodes/predefined/mv_normal_weightedmean_precision_tests.jl +++ b/test/nodes/predefined/mv_normal_weightedmean_precision_tests.jl @@ -22,14 +22,14 @@ MvNormalWeightedMeanPrecision, ) marginalsξ = ( - Marginal(q_out, false, false, nothing), - Marginal(q_ξ, false, false, nothing), - Marginal(q_Λ, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_ξ, false, false), + Marginal(q_Λ, false, false), ) marginalsμ = ( - Marginal(q_out, false, false, nothing), - Marginal(q_μ, false, false, nothing), - Marginal(q_Σ, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_μ, false, false), + Marginal(q_Σ, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/normal_mean_precision_tests.jl b/test/nodes/predefined/normal_mean_precision_tests.jl index 9d0b7ffe2..fb4267b7a 100644 --- a/test/nodes/predefined/normal_mean_precision_tests.jl +++ b/test/nodes/predefined/normal_mean_precision_tests.jl @@ -16,9 +16,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -35,8 +35,7 @@ q_τ = GammaShapeRate(1.5, 1.5) marginals = ( - Marginal(q_out_μ, false, false, nothing), - Marginal(q_τ, false, false, nothing), + Marginal(q_out_μ, false, false), Marginal(q_τ, false, false) ) @test score( AverageEnergy(), @@ -54,8 +53,7 @@ ) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_μ_τ, false, false, nothing), + Marginal(q_out, false, false), Marginal(q_μ_τ, false, false) ) @test score( AverageEnergy(), @@ -73,7 +71,7 @@ τ = GammaShapeRate(1.5, 1.5), ) - marginals = (Marginal(q_out_μ_τ, false, false, nothing),) + marginals = (Marginal(q_out_μ_τ, false, false),) @test score( AverageEnergy(), NormalMeanPrecision, @@ -96,9 +94,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -123,9 +121,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -150,9 +148,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -176,8 +174,8 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -201,8 +199,8 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -219,9 +217,9 @@ q_b = PointMass(1.0) q_c = PointMass(1.0) marginals = ( - Marginal(q_a, false, false, nothing), - Marginal(q_b, false, false, nothing), - Marginal(q_c, false, false, nothing), + Marginal(q_a, false, false), + Marginal(q_b, false, false), + Marginal(q_c, false, false), ) meta = 1 @test_throws r"Cannot compute Average Energy for the .*NormalMeanPrecision node, the method does not exist for the provided marginals." score( diff --git a/test/nodes/predefined/normal_mean_variance_tests.jl b/test/nodes/predefined/normal_mean_variance_tests.jl index 7ba7b628f..5b7a180b0 100644 --- a/test/nodes/predefined/normal_mean_variance_tests.jl +++ b/test/nodes/predefined/normal_mean_variance_tests.jl @@ -16,9 +16,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_τ), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_τ), false, false), ) @test score( AverageEnergy(), @@ -43,9 +43,9 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(convert(N, q_μ), false, false, nothing), - Marginal(convert(G, q_v), false, false, nothing), + Marginal(q_out, false, false), + Marginal(convert(N, q_μ), false, false), + Marginal(convert(G, q_v), false, false), ) @test score( AverageEnergy(), @@ -75,8 +75,8 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_v), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_v), false, false), ) @test score( AverageEnergy(), @@ -106,8 +106,8 @@ G in (GammaShapeRate, GammaShapeScale) marginals = ( - Marginal(convert(N, q_out_μ), false, false, nothing), - Marginal(convert(G, q_v), false, false, nothing), + Marginal(convert(N, q_out_μ), false, false), + Marginal(convert(G, q_v), false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/normal_mixture_tests.jl b/test/nodes/predefined/normal_mixture_tests.jl index 860f69cc7..ffc8aca45 100644 --- a/test/nodes/predefined/normal_mixture_tests.jl +++ b/test/nodes/predefined/normal_mixture_tests.jl @@ -13,10 +13,10 @@ q_p = (GammaShapeRate(2.0, 3.0), GammaShapeRate(4.0, 5.0)) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_switch, false, false, nothing), - ManyOf(map(q_m_ -> Marginal(q_m_, false, false, nothing), q_m)), - ManyOf(map(q_p_ -> Marginal(q_p_, false, false, nothing), q_p)), + Marginal(q_out, false, false), + Marginal(q_switch, false, false), + ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)), + ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p)), ) ref_val = @@ -25,7 +25,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1]), ), nothing, @@ -35,7 +35,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2]), ), nothing, @@ -56,10 +56,10 @@ q_p = (GammaShapeRate(2.0, 3.0), GammaShapeRate(1.0, 5.0)) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_switch, false, false, nothing), - ManyOf(map(q_m_ -> Marginal(q_m_, false, false, nothing), q_m)), - ManyOf(map(q_p_ -> Marginal(q_p_, false, false, nothing), q_p)), + Marginal(q_out, false, false), + Marginal(q_switch, false, false), + ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)), + ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p)), ) ref_val = @@ -68,7 +68,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1]), ), nothing, @@ -78,7 +78,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2]), ), nothing, @@ -99,10 +99,10 @@ q_p = (GammaShapeRate(3.0, 3.0), GammaShapeRate(4.0, 5.0)) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_switch, false, false, nothing), - ManyOf(map(q_m_ -> Marginal(q_m_, false, false, nothing), q_m)), - ManyOf(map(q_p_ -> Marginal(q_p_, false, false, nothing), q_p)), + Marginal(q_out, false, false), + Marginal(q_switch, false, false), + ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)), + ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p)), ) ref_val = @@ -111,7 +111,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1]), ), nothing, @@ -121,7 +121,7 @@ NormalMeanPrecision, Val{(:out, :μ, :τ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2]), ), nothing, @@ -148,10 +148,10 @@ ) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_switch, false, false, nothing), - ManyOf(map(q_m_ -> Marginal(q_m_, false, false, nothing), q_m)), - ManyOf(map(q_p_ -> Marginal(q_p_, false, false, nothing), q_p)), + Marginal(q_out, false, false), + Marginal(q_switch, false, false), + ManyOf(map(q_m_ -> Marginal(q_m_, false, false), q_m)), + ManyOf(map(q_p_ -> Marginal(q_p_, false, false), q_p)), ) ref_val = @@ -160,7 +160,7 @@ MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[1], q_p[1]), ), nothing, @@ -170,7 +170,7 @@ MvNormalMeanPrecision, Val{(:out, :μ, :Λ)}(), map( - (q) -> Marginal(q, false, false, nothing), + (q) -> Marginal(q, false, false), (q_out, q_m[2], q_p[2]), ), nothing, diff --git a/test/nodes/predefined/poisson_tests.jl b/test/nodes/predefined/poisson_tests.jl index ebec70c02..3808d22a7 100644 --- a/test/nodes/predefined/poisson_tests.jl +++ b/test/nodes/predefined/poisson_tests.jl @@ -10,8 +10,8 @@ Poisson, Val{(:out, :l)}(), ( - Marginal(PointMass(k), false, false, nothing), - Marginal(PointMass(l), false, false, nothing), + Marginal(PointMass(k), false, false), + Marginal(PointMass(l), false, false), ), nothing, ), @@ -27,8 +27,8 @@ Poisson, Val{(:out, :l)}(), ( - Marginal(Poisson(k), false, false, nothing), - Marginal(PointMass(k), false, false, nothing), + Marginal(Poisson(k), false, false), + Marginal(PointMass(k), false, false), ), nothing, ), @@ -44,8 +44,8 @@ Poisson, Val{(:out, :l)}(), ( - Marginal(Poisson(k), false, false, nothing), - Marginal(PointMass(k), false, false, nothing), + Marginal(Poisson(k), false, false), + Marginal(PointMass(k), false, false), ), nothing, ), diff --git a/test/nodes/predefined/probit_tests.jl b/test/nodes/predefined/probit_tests.jl index 319f40fea..4af31aec1 100644 --- a/test/nodes/predefined/probit_tests.jl +++ b/test/nodes/predefined/probit_tests.jl @@ -8,8 +8,8 @@ Probit, Val{(:out, :in)}(), ( - Marginal(Bernoulli(1), false, false, nothing), - Marginal(NormalMeanVariance(0.0, 1.0), false, false, nothing), + Marginal(Bernoulli(1), false, false), + Marginal(NormalMeanVariance(0.0, 1.0), false, false), ), ProbitMeta(), ) ≈ 1.0 @@ -19,8 +19,8 @@ Probit, Val{(:out, :in)}(), ( - Marginal(PointMass(1), false, false, nothing), - Marginal(NormalMeanVariance(0.0, 1.0), false, false, nothing), + Marginal(PointMass(1), false, false), + Marginal(NormalMeanVariance(0.0, 1.0), false, false), ), ProbitMeta(100), ) ≈ 1.0 @@ -31,10 +31,8 @@ Probit, Val{(:out, :in)}(), ( - Marginal(Bernoulli(k), false, false, nothing), - Marginal( - NormalMeanVariance(0.0, 1.0), false, false, nothing - ), + Marginal(Bernoulli(k), false, false), + Marginal(NormalMeanVariance(0.0, 1.0), false, false), ), ProbitMeta(100), ) ≈ 1.0 diff --git a/test/nodes/predefined/softdot_tests.jl b/test/nodes/predefined/softdot_tests.jl index a53a69b3d..bf62d53bc 100644 --- a/test/nodes/predefined/softdot_tests.jl +++ b/test/nodes/predefined/softdot_tests.jl @@ -9,10 +9,10 @@ q_x = NormalMeanVariance(5.0, 9.0) q_γ = GammaShapeRate(3 / 2, 4242 / 2) marginals = ( - Marginal(q_y, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_x, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y, false, false), + Marginal(q_θ, false, false), + Marginal(q_x, false, false), + Marginal(q_γ, false, false), ) @test score( @@ -30,10 +30,10 @@ q_x = MvNormalMeanCovariance([5.0, 9.0], [11.0 13.0; 17.0 19.0]) q_γ = GammaShapeRate(3 / 2, 191032 / 2) marginals = ( - Marginal(q_y, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_x, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y, false, false), + Marginal(q_θ, false, false), + Marginal(q_x, false, false), + Marginal(q_γ, false, false), ) @test score( @@ -51,9 +51,9 @@ q_γ = GammaShapeRate(2.0, 3.0) marginals = ( - Marginal(q_y_x, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y_x, false, false), + Marginal(q_θ, false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), @@ -70,9 +70,9 @@ q_γ = GammaShapeRate(2.0, 3.0) marginals = ( - Marginal(q_y_x, false, false, nothing), - Marginal(q_θ, false, false, nothing), - Marginal(q_γ, false, false, nothing), + Marginal(q_y_x, false, false), + Marginal(q_θ, false, false), + Marginal(q_γ, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/uniform_tests.jl b/test/nodes/predefined/uniform_tests.jl index 691834398..e6031c551 100644 --- a/test/nodes/predefined/uniform_tests.jl +++ b/test/nodes/predefined/uniform_tests.jl @@ -9,9 +9,9 @@ Uniform, Val{(:out, :a, :b)}(), ( - Marginal(Beta(α, β), false, false, nothing), - Marginal(PointMass(a), false, false, nothing), - Marginal(PointMass(b), false, false, nothing), + Marginal(Beta(α, β), false, false), + Marginal(PointMass(a), false, false), + Marginal(PointMass(b), false, false), ), nothing, ) == 0.0 diff --git a/test/nodes/predefined/wishart_inverse_tests.jl b/test/nodes/predefined/wishart_inverse_tests.jl index 9a392131f..e95f001ac 100644 --- a/test/nodes/predefined/wishart_inverse_tests.jl +++ b/test/nodes/predefined/wishart_inverse_tests.jl @@ -11,9 +11,9 @@ q_S = PointMass([2.0 0.0; 0.0 2.0]) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_ν, false, false, nothing), - Marginal(q_S, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_ν, false, false), + Marginal(q_S, false, false), ) @test score( AverageEnergy(), @@ -32,9 +32,9 @@ q_S = PointMass(S) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_ν, false, false, nothing), - Marginal(q_S, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_ν, false, false), + Marginal(q_S, false, false), ) @test score( AverageEnergy(), diff --git a/test/nodes/predefined/wishart_tests.jl b/test/nodes/predefined/wishart_tests.jl index d21ade1eb..2eb565a4e 100644 --- a/test/nodes/predefined/wishart_tests.jl +++ b/test/nodes/predefined/wishart_tests.jl @@ -11,9 +11,9 @@ q_S = PointMass([2.0 0.0; 0.0 2.0]) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_ν, false, false, nothing), - Marginal(q_S, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_ν, false, false), + Marginal(q_S, false, false), ) @test score( @@ -38,9 +38,9 @@ q_S = PointMass(S) marginals = ( - Marginal(q_out, false, false, nothing), - Marginal(q_ν, false, false, nothing), - Marginal(q_S, false, false, nothing), + Marginal(q_out, false, false), + Marginal(q_ν, false, false), + Marginal(q_S, false, false), ) @test score( diff --git a/test/pipeline/logger_tests.jl b/test/pipeline/logger_tests.jl deleted file mode 100644 index 6e17c0d9b..000000000 --- a/test/pipeline/logger_tests.jl +++ /dev/null @@ -1,61 +0,0 @@ - -@testitem "LoggerPipelineStage" begin - using ReactiveMP - using Distributions - using Rocket - - import ReactiveMP: tag, factornode, getinterfaces - - struct DummyNodeForLoggerTests end - - @node DummyNodeForLoggerTests Stochastic [out, x, y] - - # In real applications the stream should be a stream of messages - # For testing purposes it does not really matter though - stream = Subject(String) - node = factornode( - DummyNodeForLoggerTests, - [(:out, randomvar()), (:x, randomvar()), (:y, randomvar())], - ((1, 2, 3),), - ) - - @testset "no prefix" begin - io = IOBuffer() - pipeline = LoggerPipelineStage(io) - modified_stream = apply_pipeline_stage( - pipeline, node, tag(first(getinterfaces(node))), stream - ) - subscription = subscribe!(modified_stream, void()) - - next!(stream, "hello") - - logged_str = String(take!(io)) - - @test contains(logged_str, "Log") # default prefix - @test contains(logged_str, "DummyNodeForLoggerTests") - @test contains(logged_str, "out") - @test contains(logged_str, "hello") - - unsubscribe!(subscription) - end - - @testset "with custom prefix" begin - io = IOBuffer() - pipeline = LoggerPipelineStage(io, "custom_prefix") - modified_stream = apply_pipeline_stage( - pipeline, node, tag(first(getinterfaces(node))), stream - ) - subscription = subscribe!(modified_stream, void()) - - next!(stream, "hello") - - logged_str = String(take!(io)) - - @test contains(logged_str, "custom_prefix") - @test contains(logged_str, "DummyNodeForLoggerTests") - @test contains(logged_str, "out") - @test contains(logged_str, "hello") - - unsubscribe!(subscription) - end -end diff --git a/test/rule_tests.jl b/test/rule_tests.jl index 0533b82c9..e4fcb1ba8 100644 --- a/test/rule_tests.jl +++ b/test/rule_tests.jl @@ -650,10 +650,8 @@ @test names == :(Val{(:out, :mean)}()) @test values == :( - ReactiveMP.Message(PointMass(1.0), false, false, nothing), - ReactiveMP.Message( - NormalMeanPrecision(0.0, 1.0), false, false, nothing - ), + ReactiveMP.Message(PointMass(1.0), false, false), + ReactiveMP.Message(NormalMeanPrecision(0.0, 1.0), false, false), ) names, values = call_rule_macro_parse_fn_args( @@ -682,7 +680,7 @@ @test names == :(Val{(:mean,)}()) @test values == :(( ReactiveMP.Marginal( - NormalMeanPrecision(0.0, 1.0), false, false, nothing + NormalMeanPrecision(0.0, 1.0), false, false ), )) end @@ -690,12 +688,8 @@ @testset "Error utilities" begin @testset "rule_method_error" begin - as_vague_msg(::Type{T}) where {T} = Message( - vague(T), false, false, nothing - ) - as_vague_mrg(::Type{T}) where {T} = Marginal( - vague(T), false, false, nothing - ) + as_vague_msg(::Type{T}) where {T} = Message(vague(T), false, false) + as_vague_mrg(::Type{T}) where {T} = Marginal(vague(T), false, false) let err = ReactiveMP.RuleMethodError( @@ -808,14 +802,7 @@ nothing, nothing, Val{(:out_μ,)}(), - ( - Marginal( - vague(MvNormalMeanPrecision, 2), - false, - false, - nothing, - ), - ), + (Marginal(vague(MvNormalMeanPrecision, 2), false, false),), 1.0, nothing, nothing, @@ -843,14 +830,7 @@ as_vague_msg(NormalMeanVariance), ), Val{(:out_μ,)}(), - ( - Marginal( - vague(MvNormalMeanPrecision, 2), - false, - false, - nothing, - ), - ), + (Marginal(vague(MvNormalMeanPrecision, 2), false, false),), 1.0, nothing, nothing, @@ -874,11 +854,11 @@ Marginalisation(), Val{(:out, :b)}(), ( - Message(PointMass(1), false, false, nothing), - Message(PointMass, false, false, nothing), + Message(PointMass(1), false, false), + Message(PointMass, false, false), ), Val{(:out_b,)}(), - (Marginal(PointMass(1), false, false, nothing),), + (Marginal(PointMass(1), false, false),), 1.0, nothing, nothing, @@ -905,11 +885,11 @@ Marginalisation(), Val{(:out, :b)}(), ( - Message(PointMass(1), false, false, nothing), - Message(PointMass(1), false, false, nothing), + Message(PointMass(1), false, false), + Message(PointMass(1), false, false), ), Val{(:out_b,)}(), - (Marginal(PointMass(1), false, false, nothing),), + (Marginal(PointMass(1), false, false),), 1.0, nothing, nothing, @@ -940,9 +920,9 @@ Val{:a}(), Marginalisation(), Val{(:out,)}(), - (Message(PointMass(1), false, false, nothing),), + (Message(PointMass(1), false, false),), Val{(:a,)}(), - (Marginal(PointMass(1), false, false, nothing),), + (Marginal(PointMass(1), false, false),), 1.0, nothing, nothing, @@ -979,7 +959,7 @@ nothing, nothing, Val{(:a,)}(), - (Marginal(PointMass(1), false, false, nothing),), + (Marginal(PointMass(1), false, false),), "meta", nothing, nothing, @@ -1004,12 +984,8 @@ end @testset "marginalrule_method_error" begin - as_vague_msg(::Type{T}) where {T} = Message( - vague(T), false, false, nothing - ) - as_vague_mrg(::Type{T}) where {T} = Marginal( - vague(T), false, false, nothing - ) + as_vague_msg(::Type{T}) where {T} = Message(vague(T), false, false) + as_vague_mrg(::Type{T}) where {T} = Marginal(vague(T), false, false) let err = ReactiveMP.MarginalRuleMethodError( @@ -1109,14 +1085,7 @@ nothing, nothing, Val{(:out_μ,)}(), - ( - Marginal( - vague(MvNormalMeanPrecision, 2), - false, - false, - nothing, - ), - ), + (Marginal(vague(MvNormalMeanPrecision, 2), false, false),), 1.0, nothing, ) @@ -1141,14 +1110,7 @@ as_vague_msg(NormalMeanVariance), ), Val{(:out_μ,)}(), - ( - Marginal( - vague(MvNormalMeanPrecision, 2), - false, - false, - nothing, - ), - ), + (Marginal(vague(MvNormalMeanPrecision, 2), false, false),), 1.0, nothing, ) @@ -1292,25 +1254,20 @@ )) end - @testset "Check the `return_addons` option" begin - # Enable LogScale addon - dist_and_addons = @call_rule [return_addons = true] Bernoulli( - :out, Marginalisation - ) (m_p = Beta(1, 2), addons = (AddonLogScale(),)) + @testset "Check the `annotations` option" begin + import ReactiveMP: AnnotationDict, has_annotation - @test dist_and_addons isa Tuple - @test length(dist_and_addons) === 2 - @test dist_and_addons[1] isa Bernoulli - @test dist_and_addons[2] isa Tuple{AddonLogScale} + # Pass an AnnotationDict to capture annotations written by the rule + ann = AnnotationDict() + dist = @call_rule Bernoulli(:out, Marginalisation) ( + m_p = Beta(1, 2), annotations = ann + ) - # Without addons but with the option - dist_and_nothing = @call_rule [return_addons = true] Bernoulli( - :out, Marginalisation - ) (m_p = Beta(1, 2),) + @test dist isa Bernoulli + @test has_annotation(ann, :logscale) - @test dist_and_nothing isa Tuple - @test length(dist_and_nothing) === 2 - @test dist_and_nothing[1] isa Bernoulli - @test dist_and_nothing[2] isa Nothing + # Without annotations keyword — rule still works, annotations are discarded + dist2 = @call_rule Bernoulli(:out, Marginalisation) (m_p = Beta(1, 2),) + @test dist2 isa Bernoulli end end diff --git a/test/testutilities.jl b/test/testutilities.jl index d45a2a458..3b99a5c65 100644 --- a/test/testutilities.jl +++ b/test/testutilities.jl @@ -49,5 +49,5 @@ function check_stream_not_updated(stream) return check_stream_not_updated(() -> nothing, stream) end -msg(value) = Message(value, false, false, nothing) -mgl(value) = Marginal(value, false, false, nothing) +msg(value) = Message(value, false, false) +mgl(value) = Marginal(value, false, false) diff --git a/test/variables/constant_tests.jl b/test/variables/constant_tests.jl index 968bcf16a..576a82211 100644 --- a/test/variables/constant_tests.jl +++ b/test/variables/constant_tests.jl @@ -1,25 +1,33 @@ @testitem "ConstVariable: uninitialized" begin - import ReactiveMP: messageout, messagein + import ReactiveMP: + get_stream_of_outbound_messages, get_stream_of_inbound_messages # Should throw if not initialised properly let var = constvar(1) for i in 1:10 - @test messageout(var, 1) === messageout(var, i) - @test_throws ErrorException messagein(var, i) + @test get_stream_of_outbound_messages(var, 1) === + get_stream_of_outbound_messages(var, i) + @test_throws ErrorException get_stream_of_inbound_messages(var, i) end end end -@testitem "ConstVariable: getmessagein!" begin - import ReactiveMP: MessageObservable, create_messagein!, messagein, degree +@testitem "ConstVariable: get_stream_of_inbound_messages" begin + import ReactiveMP: + MessageObservable, + create_new_stream_of_inbound_messages!, + get_stream_of_inbound_messages, + degree # Test for different degrees `d` for d in 1:5:100 let var = constvar(1) for i in 1:d - messagein, index = create_messagein!(var) - @test messagein isa MessageObservable + new_stream_of_inbound_messages, index = create_new_stream_of_inbound_messages!( + var + ) + @test new_stream_of_inbound_messages isa MessageObservable @test index === 1 @test degree(var) === i end @@ -28,18 +36,17 @@ end end end -@testitem "ConstVariable: getmarginal" begin +@testitem "ConstVariable: get_stream_of_marginals" begin using BayesBase import ReactiveMP: MessageObservable, - create_messagein!, - messagein, degree, activate!, connect!, DataVariableActivationOptions, - messageout + get_stream_of_outbound_messages, + get_stream_of_marginals include("../testutilities.jl") @@ -47,9 +54,10 @@ end for d in 1:5:100, constant in rand(10) let var = constvar(constant) marginal_expected = mgl(PointMass(constant)) - marginal_result = check_stream_updated_once(getmarginal(var)) do - nothing - end + marginal_result = + check_stream_updated_once(get_stream_of_marginals(var)) do + nothing + end @test getdata(marginal_result) === getdata(marginal_expected) @test getdata(marginal_result) === PointMass(constant) diff --git a/test/variables/data_tests.jl b/test/variables/data_tests.jl index ef1f48276..114f41ea8 100644 --- a/test/variables/data_tests.jl +++ b/test/variables/data_tests.jl @@ -1,25 +1,33 @@ @testitem "DataVariable: uninitialized" begin - import ReactiveMP: messageout, messagein + import ReactiveMP: + get_stream_of_outbound_messages, get_stream_of_inbound_messages # Should throw if not initialised properly let var = datavar() for i in 1:10 - @test messageout(var, 1) === messageout(var, i) - @test_throws BoundsError messagein(var, i) + @test get_stream_of_outbound_messages(var, 1) === + get_stream_of_outbound_messages(var, i) + @test_throws BoundsError get_stream_of_inbound_messages(var, i) end end end -@testitem "DataVariable: getmessagein!" begin - import ReactiveMP: MessageObservable, create_messagein!, messagein, degree +@testitem "DataVariable: get_stream_of_inbound_messages" begin + import ReactiveMP: + MessageObservable, + create_new_stream_of_inbound_messages!, + get_stream_of_inbound_messages, + degree # Test for different degrees `d` for d in 1:5:100 let var = datavar() for i in 1:d - messagein, index = create_messagein!(var) - @test messagein isa MessageObservable + new_stream_of_inbound_message, index = create_new_stream_of_inbound_messages!( + var + ) + @test new_stream_of_inbound_message isa MessageObservable @test index === i @test degree(var) === i end @@ -33,21 +41,23 @@ end import ReactiveMP: MessageObservable, - create_messagein!, - messagein, + create_new_stream_of_inbound_messages!, + get_stream_of_inbound_messages, degree, activate!, connect!, + new_observation!, DataVariableActivationOptions, - messageout + get_stream_of_outbound_messages, + get_stream_of_marginals include("../testutilities.jl") for d in 1:5:100 let var = datavar() - messageins = map(1:d) do _ + new_streams_of_inbound_messages = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(var) + m, i = create_new_stream_of_inbound_messages!(var) connect!(m, s) return s end @@ -59,18 +69,21 @@ end messages = map(msg, rand(d)) - @test check_stream_not_updated(getmarginal(var)) do - foreach(zip(messageins, messages)) do (messagein, message) - next!(messagein, message) + @test check_stream_not_updated(get_stream_of_marginals(var)) do + foreach( + zip(new_streams_of_inbound_messages, messages) + ) do (new_stream_of_inbound_messages, message) + next!(new_stream_of_inbound_messages, message) end end data_point = rand() marginal_expected = mgl(PointMass(data_point)) - marginal_result = check_stream_updated_once(getmarginal(var)) do - update!(var, data_point) - end + marginal_result = + check_stream_updated_once(get_stream_of_marginals(var)) do + new_observation!(var, data_point) + end @test getdata(marginal_result) === getdata(marginal_expected) @test getdata(marginal_result) === PointMass(data_point) @@ -81,7 +94,12 @@ end @testitem "DataVariable: linked variable" begin using BayesBase import ReactiveMP: - DataVariable, DataVariableActivationOptions, activate!, messageout + DataVariable, + DataVariableActivationOptions, + activate!, + get_stream_of_outbound_messages, + get_stream_of_marginals, + new_observation! include("../testutilities.jl") @@ -91,9 +109,11 @@ end true, true, fn, (val1, val2) ) activate!(var, options) - marginal = check_stream_updated_once(getmarginal(var)) + marginal = check_stream_updated_once(get_stream_of_marginals(var)) @test getdata(marginal) === PointMass(fn(val1, val2)) - message = check_stream_updated_once(messageout(var, 1)) + message = check_stream_updated_once( + get_stream_of_outbound_messages(var, 1) + ) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -103,7 +123,7 @@ end true, true, fn, (val1, val2) ) activate!(var, options) - marginal = check_stream_updated_once(getmarginal(var)) + marginal = check_stream_updated_once(get_stream_of_marginals(var)) @test getdata(marginal) === PointMass(fn(val1, val2)) end @@ -113,7 +133,9 @@ end true, true, fn, (val1, val2) ) activate!(var, options) - message = check_stream_updated_once(messageout(var, 1)) + message = check_stream_updated_once( + get_stream_of_outbound_messages(var, 1) + ) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -129,13 +151,16 @@ end true, true, fn, (var1, val2) ) activate!(var, options) - @test check_stream_not_updated(getmarginal(var)) + @test check_stream_not_updated(get_stream_of_marginals(var)) - marginal = check_stream_updated_once(getmarginal(var)) do - update!(var1, val1) - end + marginal = + check_stream_updated_once(get_stream_of_marginals(var)) do + new_observation!(var1, val1) + end @test getdata(marginal) === PointMass(fn(val1, val2)) - message = check_stream_updated_once(messageout(var, 1)) + message = check_stream_updated_once( + get_stream_of_outbound_messages(var, 1) + ) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -151,14 +176,17 @@ end true, true, fn, (val1, var2) ) activate!(var, options) - @test check_stream_not_updated(getmarginal(var)) + @test check_stream_not_updated(get_stream_of_marginals(var)) - marginal = check_stream_updated_once(getmarginal(var)) do - update!(var2, val2) - end + marginal = + check_stream_updated_once(get_stream_of_marginals(var)) do + new_observation!(var2, val2) + end @test getdata(marginal) === PointMass(fn(val1, val2)) - message = check_stream_updated_once(messageout(var, 1)) + message = check_stream_updated_once( + get_stream_of_outbound_messages(var, 1) + ) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -179,15 +207,18 @@ end true, true, fn, (var1, var2) ) activate!(var, options) - @test check_stream_not_updated(getmarginal(var)) + @test check_stream_not_updated(get_stream_of_marginals(var)) - marginal = check_stream_updated_once(getmarginal(var)) do - update!(var1, val1) - update!(var2, val2) - end + marginal = + check_stream_updated_once(get_stream_of_marginals(var)) do + new_observation!(var1, val1) + new_observation!(var2, val2) + end @test getdata(marginal) === PointMass(fn(val1, val2)) - message = check_stream_updated_once(messageout(var, 1)) + message = check_stream_updated_once( + get_stream_of_outbound_messages(var, 1) + ) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -208,12 +239,13 @@ end true, true, fn, (var1, var2) ) activate!(var, options) - @test check_stream_not_updated(getmarginal(var)) + @test check_stream_not_updated(get_stream_of_marginals(var)) # We still should be able to update the stream manually - marginal = check_stream_updated_once(getmarginal(var)) do - update!(var, 4) - end + marginal = + check_stream_updated_once(get_stream_of_marginals(var)) do + new_observation!(var, 4) + end @test getdata(marginal) === PointMass(4) end end diff --git a/test/variables/random_tests.jl b/test/variables/random_tests.jl index b298d62fe..c8ccfc6fb 100644 --- a/test/variables/random_tests.jl +++ b/test/variables/random_tests.jl @@ -1,25 +1,32 @@ @testitem "RandomVariable: uninitialized" begin - import ReactiveMP: messageout, messagein + import ReactiveMP: + get_stream_of_outbound_messages, get_stream_of_inbound_messages # Should throw if not initialised properly let var = randomvar() for i in 1:10 - @test_throws BoundsError messageout(var, i) - @test_throws BoundsError messagein(var, i) + @test_throws BoundsError get_stream_of_outbound_messages(var, i) + @test_throws BoundsError get_stream_of_inbound_messages(var, i) end end end -@testitem "RandomVariable: getmessagein!" begin - import ReactiveMP: MessageObservable, create_messagein!, messagein, degree +@testitem "RandomVariable: getget_stream_of_inbound_messages!" begin + import ReactiveMP: + MessageObservable, + create_new_stream_of_inbound_messages!, + get_stream_of_inbound_messages, + degree # Test for different degrees `d` for d in 1:5:100 let var = randomvar() for i in 1:d - messagein, index = create_messagein!(var) - @test messagein isa MessageObservable + new_stream_of_inbound_messages, index = create_new_stream_of_inbound_messages!( + var + ) + @test new_stream_of_inbound_messages isa MessageObservable @test index === i @test degree(var) === i end @@ -28,26 +35,30 @@ end end end -@testitem "RandomVariable: getmarginal" begin +@testitem "RandomVariable: get_stream_of_marginals" begin import ReactiveMP: MessageObservable, - create_messagein!, - messagein, + MessageProductContext, + create_new_stream_of_inbound_messages!, + compute_product_of_messages, + get_stream_of_inbound_messages, degree, activate!, connect!, RandomVariableActivationOptions, - messageout + get_stream_of_outbound_messages, + get_stream_of_marginals include("../testutilities.jl") - message_prod_fn = (msgs) -> error("Messages should not be called here") - marginal_prod_fn = (msgs) -> mgl(sum(getdata.(msgs))) + message_prod_fold = + (variable, context, msgs) -> error("Messages should not be called here") + marginal_prod_fold = (variable, context, msgs) -> msg(sum(getdata.(msgs))) for d in 1:5:100 let var = randomvar() - messageins = map(1:d) do _ + new_stream_of_inbound_messages = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(var) + m, i = create_new_stream_of_inbound_messages!(var) connect!(m, s) return s end @@ -55,48 +66,58 @@ end activate!( var, RandomVariableActivationOptions( - AsapScheduler(), message_prod_fn, marginal_prod_fn + nothing, + MessageProductContext(; fold_strategy = message_prod_fold), + MessageProductContext(; fold_strategy = marginal_prod_fold), ), ) messages = map(msg, rand(d)) - marginal_expected = marginal_prod_fn(messages) - marginal_result = check_stream_updated_once(getmarginal(var)) do - foreach(zip(messageins, messages)) do (messagein, message) - next!(messagein, message) + marginal_expected = mgl(sum(getdata.(messages))) + marginal_result = + check_stream_updated_once(get_stream_of_marginals(var)) do + foreach( + zip(new_stream_of_inbound_messages, messages) + ) do (new_stream_of_inbound_messages, message) + next!(new_stream_of_inbound_messages, message) + end end - end - # We check the `getdata` here approximatelly because the `marginal_prod_fn` can rearrange + # We check the `getdata` here approximatelly because the `marginal_prod_fn` can rearrange # the messages under the hood that introduces minor numerical differences @test getdata(marginal_result) ≈ getdata(marginal_expected) end end end -@testitem "RandomVariable: messageout" begin +@testitem "RandomVariable: get_stream_of_outbound_messages" begin import ReactiveMP: MessageObservable, - create_messagein!, - messagein, + MessageProductContext, + create_new_stream_of_inbound_messages!, + compute_product_of_messages, + get_stream_of_inbound_messages, degree, activate!, connect!, RandomVariableActivationOptions, - messageout + get_stream_of_outbound_messages include("../testutilities.jl") - message_prod_fn = (msgs) -> msg(sum(filter(!ismissing, getdata.(msgs)))) - marginal_prod_fn = (msgs) -> error("Marginal should not be called here") + message_prod_fold = + (variable, context, msgs) -> + msg(sum(filter(!ismissing, getdata.(msgs)))) + marginal_prod_fold = + (variable, context, msgs) -> error("Marginal should not be called here") # We start from `2` because `1` is not a valid degree for a random variable for d in 2:5:100, k in 1:d let var = randomvar() - messageins = map(1:d) do _ + new_streams_of_inbound_messages = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(var) + m, i = create_new_stream_of_inbound_messages!(var) connect!(m, s) return s end @@ -104,28 +125,176 @@ end activate!( var, RandomVariableActivationOptions( - AsapScheduler(), message_prod_fn, marginal_prod_fn + nothing, + MessageProductContext(; fold_strategy = message_prod_fold), + MessageProductContext(; fold_strategy = marginal_prod_fold), ), ) messages = map(msg, rand(d)) # the outbound message is the result of multiplication of `n - 1` messages excluding index `k` - kmessage_expected = message_prod_fn(collect(skipindex(messages, k))) - kmessage_result = check_stream_updated_once(messageout(var, k)) do - foreach(zip(messageins, messages)) do (messagein, message) - next!(messagein, message) + kmessage_expected = msg( + sum( + filter( + !ismissing, getdata.(collect(skipindex(messages, k))) + ), + ), + ) + kmessage_result = check_stream_updated_once( + get_stream_of_outbound_messages(var, k) + ) do + foreach( + zip(new_streams_of_inbound_messages, messages) + ) do (new_stream_of_inbound_messages, message) + next!(new_stream_of_inbound_messages, message) end end - # We check the `getdata` here approximatelly because the `message_prod_fn` can rearrange + # We check the `getdata` here approximatelly because the `message_prod_fn` can rearrange # the messages under the hood that introduces minor numerical differences @test getdata(kmessage_result) ≈ getdata(kmessage_expected) end end end +@testitem "RandomVariable: before/after marginal computation callbacks" begin + import ReactiveMP: + MessageObservable, + MessageProductContext, + RandomVariableActivationOptions, + AbstractMessage, + create_new_stream_of_inbound_messages!, + activate!, + connect!, + getdata, + get_stream_of_marginals + + import Rocket: Subject, next! + + include("../testutilities.jl") + + struct MarginalCallbackHandler + listen_to::Tuple + events + end + + function ReactiveMP.invoke_callback( + handler::MarginalCallbackHandler, event::ReactiveMP.Event{E} + ) where {E} + E ∈ handler.listen_to && + push!(handler.events, (event = E, data = event)) + end + + @testset "Fires before and after marginal computation with 3 messages" begin + listen_to = (:before_marginal_computation, :after_marginal_computation) + handler = MarginalCallbackHandler(listen_to, []) + marginal_context = MessageProductContext(; + fold_strategy = (variable, context, msgs) -> + msg(sum(getdata.(msgs))), + callbacks = handler, + ) + + var = randomvar() + + new_streams_of_inbounds_messages = map(1:3) do _ + s = Subject(AbstractMessage) + m, i = create_new_stream_of_inbound_messages!(var) + connect!(m, s) + return s + end + + activate!( + var, + RandomVariableActivationOptions( + nothing, MessageProductContext(), marginal_context + ), + ) + + messages = [msg(1.0), msg(2.0), msg(3.0)] + + marginal_result = + check_stream_updated_once(get_stream_of_marginals(var)) do + foreach( + zip(new_streams_of_inbounds_messages, messages) + ) do (new_stream_of_inbounds_messages, message) + next!(new_stream_of_inbounds_messages, message) + end + end + + # sum(1.0 + 2.0 + 3.0) = 6.0 + @test getdata(marginal_result) ≈ 6.0 + + @test length(handler.events) == 2 + + # Before: variable, context, messages + @test handler.events[1].event === :before_marginal_computation + @test handler.events[1].data.variable === var + @test handler.events[1].data.context === marginal_context + + # After: variable, context, messages, result + @test handler.events[2].event === :after_marginal_computation + @test handler.events[2].data.variable === var + @test handler.events[2].data.context === marginal_context + @test length(handler.events[2].data.messages) == 3 + @test getdata(handler.events[2].data.result) ≈ 6.0 + end + + @testset "Fires before and after marginal computation with 2 messages" begin + listen_to = (:before_marginal_computation, :after_marginal_computation) + handler = MarginalCallbackHandler(listen_to, []) + marginal_context = MessageProductContext(; + fold_strategy = (variable, context, msgs) -> + msg(sum(getdata.(msgs))), + callbacks = handler, + ) + + var = randomvar() + + new_streams_of_inbounds_messages = map(1:2) do _ + s = Subject(AbstractMessage) + m, i = create_new_stream_of_inbound_messages!(var) + connect!(m, s) + return s + end + + activate!( + var, + RandomVariableActivationOptions( + nothing, MessageProductContext(), marginal_context + ), + ) + + messages = [msg(10.0), msg(20.0)] + + marginal_result = + check_stream_updated_once(get_stream_of_marginals(var)) do + foreach( + zip(new_streams_of_inbounds_messages, messages) + ) do (new_stream_of_inbound_messages, message) + next!(new_stream_of_inbound_messages, message) + end + end + + # sum(10.0 + 20.0) = 30.0 + @test getdata(marginal_result) ≈ 30.0 + + @test length(handler.events) == 2 + + @test handler.events[1].event === :before_marginal_computation + @test handler.events[1].data.variable === var + + @test handler.events[2].event === :after_marginal_computation + @test handler.events[2].data.variable === var + @test length(handler.events[2].data.messages) == 2 + @test getdata(handler.events[2].data.result) ≈ 30.0 + end +end + @testitem "RandomVariable: activate! - zero or less than one inbound messages should throw" begin - import ReactiveMP: RandomVariableActivationOptions, activate!, messageout + import ReactiveMP: + RandomVariableActivationOptions, + activate!, + get_stream_of_outbound_messages let var = randomvar() @test_throws "Cannot activate a random variable with zero or less than one inbound messages." activate!( diff --git a/test/variables/variable_tests.jl b/test/variables/variable_tests.jl index fcac1253d..dcb586c63 100644 --- a/test/variables/variable_tests.jl +++ b/test/variables/variable_tests.jl @@ -2,8 +2,12 @@ @testitem "Variable" begin using ReactiveMP, Rocket, BayesBase, Distributions, ExponentialFamily - import ReactiveMP: activate! - import Rocket: getscheduler + import ReactiveMP: + activate!, + get_stream_of_marginals, + degree, + set_initial_marginal!, + set_initial_message! struct CustomDeterministicNodeForVariableTests end @@ -26,15 +30,15 @@ @test degree(variable) === k - # Check that before calling the `setmarginals!` all marginals are `nothing` - @test isnothing(Rocket.getrecent(getmarginal(variable, IncludeAll()))) + # Check that before calling the `set_initial_marginal!` all marginals are `nothing` + @test isnothing(Rocket.getrecent(get_stream_of_marginals(variable))) - setmarginal!(variable, dist) + set_initial_marginal!(variable, dist) marginal_subscription_flag = false - # After calling the `setmarginals!` the marginal should be equal to `dist` + # After calling the `set_initial_marginal!` the marginal should be equal to `dist` subscription = subscribe!( - getmarginal(variable, IncludeAll()), + get_stream_of_marginals(variable), (marginal) -> begin @test typeof(marginal) <: Marginal{T} @test mean(marginal) === mean(dist) @@ -45,21 +49,25 @@ @test marginal_subscription_flag === true unsubscribe!(subscription) - # Check that before calling the `setmessages!` all messages are `nothing` + # Check that before calling the `set_initial_message!` all messages are `nothing` for node_index in 1:k @test isnothing( - Rocket.getrecent(ReactiveMP.messageout(variable, node_index)) + Rocket.getrecent( + ReactiveMP.get_stream_of_outbound_messages( + variable, node_index + ), + ), ) end - for node_index in 1:k - setmessage!(variable, node_index, dist) - end + set_initial_message!(variable, dist) for node_index in 1:k message_subscription_flag = false subscription = subscribe!( - ReactiveMP.messageout(variable, node_index), + ReactiveMP.get_stream_of_outbound_messages( + variable, node_index + ), (message) -> begin @test typeof(message) <: Message{T} @test mean(message) === mean(dist) @@ -75,7 +83,7 @@ function test_variables_set_methods(variables, dist::T, k::Int) where {T} marginal_subscription_flag = false - @test_throws AssertionError setmarginals!( + @test_throws AssertionError set_initial_marginal!( variables, Iterators.repeated(dist, length(variables) - 1) ) @@ -95,25 +103,25 @@ @test all(degree.(variables) .== k) - @test_throws AssertionError setmessages!( + @test_throws AssertionError set_initial_message!( variables, Iterators.repeated(dist, length(variables) - 1) ) - @test_throws AssertionError setmessages!( + @test_throws AssertionError set_initial_message!( variables, Iterators.repeated(dist, length(variables) - 1) ) - # Test `setmarginals!` + # Test `set_initial_marginal!` - # Check that before calling the `setmarginals!` all marginals are `nothing` + # Check that before calling the `set_initial_marginal!` all marginals are `nothing` @test all( - isnothing, Rocket.getrecent.(getmarginal.(variables, IncludeAll())) + isnothing, Rocket.getrecent.(get_stream_of_marginals.(variables)) ) - setmarginals!(variables, dist) + set_initial_marginal!(variables, dist) - # After calling the `setmarginals!` all marginals should be equal to `dist` + # After calling the `set_initial_marginal!` all marginals should be equal to `dist` subscription = subscribe!( - getmarginals(variables, IncludeAll()), + collectLatest(map(get_stream_of_marginals, variables)), (marginals) -> begin @test length(marginals) === length(variables) foreach(marginals) do marginal @@ -129,23 +137,29 @@ @test marginal_subscription_flag === true unsubscribe!(subscription) - # Check that before calling the `setmessages!` all messages are `nothing` + # Check that before calling the `set_initial_message!` all messages are `nothing` for node_index in 1:k @test all( isnothing, Rocket.getrecent.( - ReactiveMP.messageout.(variables, node_index) + ReactiveMP.get_stream_of_outbound_messages.( + variables, node_index + ), ), ) end - # After calling the `setmessages!` all marginals should be equal to `dist` - setmessages!(variables, dist) + # After calling the `set_initial_message!` all marginals should be equal to `dist` + set_initial_message!(variables, dist) # For each outbound index for node_index in 1:k messages_subscription_flag = false subscription = subscribe!( - collectLatest(ReactiveMP.messageout.(variables, node_index)), + collectLatest( + ReactiveMP.get_stream_of_outbound_messages.( + variables, node_index + ), + ), (messages) -> begin @test length(messages) === length(variables) foreach(messages) do message @@ -161,7 +175,7 @@ end end - @testset "setmarginal! and setmessages! tests for randomvar" begin + @testset "set_initial_marginal! and set_initial_message! tests for randomvar" begin dists = ( NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0),