Add Evaluators: prepare interface, and vectorisation utilities #157
Merged
Commits (17, all by yebai)
- `3145f5d` Add Evaluators module: prepare/Prepared interface, evaluator shapes, …
- `0a3a935` Format with JuliaFormatter (blue style)
- `6fba5b1` Fix docs example dispatch; rename test fixtures away from Problem
- `d632db5` format
- `154b009` Make Prepared strictly internal to AbstractPPL.Evaluators
- `0db151a` Restrict LDP integration to vector-input evaluators; default capabili…
- `d0ebb6c` Preserve leaf types in unflatten_to!! round-trip
- `73c97e9` Reframe check_dims rationale around the outer entry point
- `c2b84b2` Clarify check_dims trust model and no-AD interface rationale
- `3ad1396` Document why the AD output-shape assertions live here
- `6c9e4f2` Address remaining PR review feedback
- `88f1661` Note TODO for proper structured-array / Cholesky support
- `a8ff7e1` Address late-round review feedback on utils.jl
- `0ad98f5` Validate nested-array shapes and supported leaves in NamedTupleEvaluator
- `2eecfa9` Switch flatten/unflatten to opt-in array allow-list
- `49b20aa` Address final round of PR review feedback
- `7ac3e3a` Simplify Evaluators module after review
**Docs `Project.toml`:** `ADTypes` and `ForwardDiff` are added as dependencies, and the `[sources]` path is normalised:

```diff
@@ -1,8 +1,10 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

 [sources]
-AbstractPPL = {path = "../"}
+AbstractPPL = {path = ".."}
```
**New documentation page** (163 lines added):
# Evaluator preparation and AD

AbstractPPL provides a small interface for preparing callables and asking a
prepared evaluator for values and derivatives. `prepare` binds a callable to a
sample input that establishes the expected input shape and type;
`value_and_gradient!!` and `value_and_jacobian!!` then return the value and
derivative together.

The `!!` suffix signals that the returned gradient or Jacobian **may alias
internal cache buffers** of the prepared evaluator. The next call to
`value_and_gradient!!` (or `value_and_jacobian!!`) may overwrite that buffer
in place, so a previously returned reference will silently change. Copy
before holding on to a result:

```julia
val, grad = value_and_gradient!!(prepared, x1)
saved = copy(grad)  # safe to keep
val2, grad2 = value_and_gradient!!(prepared, x2)
# `grad` may now reflect `x2`; `saved` still reflects `x1`
```

Backends that always allocate fresh output (e.g. `ForwardDiff.gradient`) do
not actually alias, but consumers should not rely on that: write to the
contract, not the implementation.
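
If a caller always wants an owned result, a small copying wrapper suffices.
This is a sketch (the name `value_and_gradient_owned` is hypothetical); the
`copy` mirrors what the LogDensityProblems extension below does before
handing gradients to LDP callers:

```julia
# Hypothetical convenience wrapper: never exposes the prepared
# evaluator's cache buffers to the caller.
function value_and_gradient_owned(p, x)
    val, grad = value_and_gradient!!(p, x)
    return (val, copy(grad))  # copy severs any aliasing with p's cache
end
```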

## Quick start

```@example ad
using AbstractPPL
using AbstractPPL: prepare, value_and_gradient!!
using AbstractPPL.Evaluators: Prepared, VectorEvaluator, NamedTupleEvaluator
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff

function AbstractPPL.prepare(adtype::AutoForwardDiff, f, x::AbstractVector{<:Real})
    return Prepared(adtype, VectorEvaluator(f, length(x)))
end

function AbstractPPL.value_and_gradient!!(
    p::Prepared{<:AutoForwardDiff}, x::AbstractVector{<:Real}
)
    return (p(x), ForwardDiff.gradient(p.evaluator.f, x))
end

mvnormal_logp(x) = -0.5 * sum(abs2, x)  # standard normal log density (up to a constant)
prepared = prepare(AutoForwardDiff(), mvnormal_logp, zeros(3))
value_and_gradient!!(prepared, [1.0, 2.0, 3.0])
```

## Two input styles

### Vector inputs

When the callable accepts a flat vector, pass a sample vector whose length
matches the expected input:

```@example ad
prepared([1.0, 2.0, 3.0])
```

For vector-valued callables, use `value_and_jacobian!!`. The returned Jacobian
has shape `(length(value), length(x))`. A backend extension that defines
`value_and_gradient!!` typically also defines `value_and_jacobian!!` on the
same `Prepared` type; they are separate generic functions, so the two methods
coexist without conflict, and the caller picks whichever applies to their
function:

```@example ad
using AbstractPPL: value_and_jacobian!!

function AbstractPPL.value_and_jacobian!!(
    p::Prepared{<:AutoForwardDiff}, x::AbstractVector{<:Real}
)
    return (p(x), ForwardDiff.jacobian(p.evaluator.f, x))
end

vecfun(x) = [x[1] * x[2], x[2] + x[3]]
prepared_vec = prepare(AutoForwardDiff(), vecfun, zeros(3))
value_and_jacobian!!(prepared_vec, [2.0, 3.0, 4.0])
```

### NamedTuple inputs

When the callable accepts a `NamedTuple`, pass a sample `NamedTuple` whose
field names and value types match the expected input. The prototype's leaves
must be `Real`, `Complex`, `AbstractArray` (recursively), `Tuple`, or
`NamedTuple`. An extension can define a `prepare` overload that wraps the
function in a `NamedTupleEvaluator`:

```@example ad
function AbstractPPL.prepare(adtype::AutoForwardDiff, f, values::NamedTuple)
    return Prepared(adtype, NamedTupleEvaluator(f, values))
end

ntfun(v::NamedTuple) = v.a^2 + sum(abs2, v.b)
prepared_nt = prepare(AutoForwardDiff(), ntfun, (a=0.0, b=zeros(2)))
prepared_nt((a=1.0, b=[2.0, 3.0]))
```
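
As an aside (not part of the diff): the leaf restriction is enforced by
`NamedTupleEvaluator` itself, per the "Validate nested-array shapes and
supported leaves" commit above. A sketch of what that implies, assuming
construction-time validation (the exact error type is not specified here):

```julia
# A `String` leaf is outside the supported set (`Real`, `Complex`,
# `AbstractArray`, `Tuple`, `NamedTuple`), so wrapping this prototype
# is expected to throw rather than fail later inside AD.
bad_prototype = (a = 0.0, b = "not a supported leaf")
NamedTupleEvaluator(ntfun, bad_prototype)  # expected to error (assumed)
```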

## AD backends

Automatic differentiation packages extend the interface by implementing
`value_and_gradient!!` and `value_and_jacobian!!` for specific cache types
stored in `prepared.cache`:

```julia
prepared = prepare(adtype, problem, prototype)  # returns Prepared{AD,E,Cache}
value_and_gradient!!(prepared, x)  # may return an aliased cache buffer
value_and_jacobian!!(prepared, x)
```

`Prepared` has three fields: `adtype`, `evaluator` (the user-facing callable),
and `cache` (backend-specific pre-allocated state such as ForwardDiff configs
or Mooncake tapes). Backend extensions dispatch on the cache type:

```julia
function AbstractPPL.prepare(
    adtype::MyADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true
)
    f = ...  # extract the callable from `problem`
    cache = MyCache(f, x)
    return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x)), cache)
end

function AbstractPPL.value_and_gradient!!(
    p::Prepared{<:AbstractADType,<:VectorEvaluator,<:MyCache}, x::AbstractVector{<:Real}
)
    # use p.cache to avoid allocations
    return ...
end
```

Pass `check_dims=false` in your `prepare` implementation to construct a
`VectorEvaluator{false}`, which skips the per-call length check. This is an
opt-in trust mode: the caller takes responsibility for `length(x)`. The
typical use is inside a backend's `value_and_gradient!!`, where the AD
library invokes the inner callable many times with same-length dual arrays
derived from a single user-supplied `x`; re-validating on each invocation
would be redundant work in the hot path.
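
As an illustration (not part of the diff), the two modes differ only in
whether the per-call length check runs; this sketch assumes the
`VectorEvaluator{check_dims}(f, dim)` constructor shown above:

```julia
checked   = VectorEvaluator{true}(sum, 3)   # validates length(x) == 3 on every call
unchecked = VectorEvaluator{false}(sum, 3)  # trusts the caller about length(x)

checked([1.0, 2.0, 3.0])    # 6.0
checked([1.0, 2.0])         # expected to throw a dimension-mismatch error
unchecked([1.0, 2.0, 3.0])  # 6.0, with no check performed
```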

## Without an AD backend

The two-argument form `prepare(problem, x)` is available without any AD
package. By default it wraps `problem` in a `VectorEvaluator{check_dims}`
(or a `NamedTupleEvaluator{check_dims}` for the `NamedTuple` form), giving
you a callable that runs the per-call shape check before forwarding to
`problem`. Downstream code that only needs primal evaluation (e.g.
log-density only, no gradient) can call `prepare(...)` uniformly without
knowing whether an AD backend is loaded:

```@example ad
sumsimple(x) = sum(x)
p = prepare(sumsimple, zeros(3))  # `VectorEvaluator{true}(sumsimple, 3)`
p([1.0, 2.0, 3.0])
```

## API reference

```@docs
AbstractPPL.prepare
AbstractPPL.value_and_gradient!!
AbstractPPL.value_and_jacobian!!
```
**New package extension** `AbstractPPLLogDensityProblemsExt` (42 lines added):
```julia
module AbstractPPLLogDensityProblemsExt

using AbstractPPL: AbstractPPL
using AbstractPPL.Evaluators: Prepared, VectorEvaluator
using LogDensityProblems: LogDensityProblems

# LDP integration is restricted to vector-input evaluators; `NamedTupleEvaluator`
# does not satisfy LDP's vector-input contract. Scalar output is a runtime
# contract the user must satisfy.

LogDensityProblems.logdensity(p::Prepared{<:Any,<:VectorEvaluator}, x) = p(x)
LogDensityProblems.logdensity(e::VectorEvaluator, x) = e(x)

function LogDensityProblems.dimension(p::Prepared{<:Any,<:VectorEvaluator})
    return LogDensityProblems.dimension(p.evaluator)
end
LogDensityProblems.dimension(e::VectorEvaluator) = e.dim

# Generic fallback: order 0. AD-backend extensions (DifferentiationInterface,
# ForwardDiff, Mooncake, etc.) must overload this for their cache type to
# advertise `LogDensityOrder{1}`; without that overload,
# `logdensity_and_gradient` will hit the `value_and_gradient!!` stub and fail.
function LogDensityProblems.capabilities(::Type{<:Prepared{<:Any,<:VectorEvaluator}})
    return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(p::Prepared{<:Any,<:VectorEvaluator})
    return LogDensityProblems.capabilities(typeof(p))
end

function LogDensityProblems.capabilities(::Type{<:VectorEvaluator})
    return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(e::VectorEvaluator)
    return LogDensityProblems.capabilities(typeof(e))
end

function LogDensityProblems.logdensity_and_gradient(p::Prepared{<:Any,<:VectorEvaluator}, x)
    val, grad = AbstractPPL.value_and_gradient!!(p, x)
    return (val, copy(grad))  # copy: `!!` results may alias internal cache buffers
end

end # module
```
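
A usage sketch, not part of the diff: it assumes the extension above is
loaded alongside `LogDensityProblems`, and the closing `capabilities`
overload is hypothetical (`MyCache` stands in for a real backend's cache
type). Without such an overload, a prepared object only advertises order 0:

```julia
using AbstractPPL: prepare
using AbstractPPL.Evaluators: Prepared, VectorEvaluator
using LogDensityProblems

logp(x) = -0.5 * sum(abs2, x)  # scalar output, as the LDP contract requires
e = prepare(logp, zeros(3))    # two-argument form: a VectorEvaluator{true}

LogDensityProblems.dimension(e)                    # 3
LogDensityProblems.capabilities(e)                 # LogDensityOrder{0}()
LogDensityProblems.logdensity(e, [1.0, 2.0, 3.0])  # -7.0

# Hypothetical backend overload advertising first-order support for its
# own cache type, so `logdensity_and_gradient` becomes usable:
struct MyCache end  # stand-in for a real backend's cache type
function LogDensityProblems.capabilities(
    ::Type{<:Prepared{<:Any,<:VectorEvaluator,<:MyCache}}
)
    return LogDensityProblems.LogDensityOrder{1}()
end
```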