Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/ExponentialFamilyProjection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ Checks the compatibility of `strategy` with `argument` and returns a modified st
"""
function preprocess_strategy_argument end

"""
check_compatibility(projection_argument, manifold, prj::ProjectedTo)

An optional interface for validating compatibility between a projection argument and the target manifold.

This function can be implemented by users to perform custom validation checks before projection.
By default, it does nothing, but users can override it to check:
- Dimensionality compatibility between the projection argument and target distribution
- Sample validity for the target distribution type
- Other domain-specific constraints

See the documentation in `projected_to.jl` for more details and examples.
"""
function check_compatibility end

"""
create_state!(
strategy,
Expand Down
73 changes: 66 additions & 7 deletions src/projected_to.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,70 @@ function check_inputs(prj::ProjectedTo, projection_argument::F, supplementary...
lazy"The initial point must be on the manifold `$(get_projected_to_manifold(prj))`, got `$(typeof(initialpoint))`",
)
end
end

"""
check_compatibility(projection_argument, manifold, prj::ProjectedTo)

An optional interface for validating compatibility between a projection argument and the target manifold.

This function can be implemented by users to perform custom validation checks before projection.
By default, it does nothing, but users can override it to check:
- Dimensionality compatibility between the projection argument and target distribution
- Sample validity for the target distribution type
- Other domain-specific constraints

# Arguments
- `projection_argument`: The argument being projected (typically a function or samples)
- `manifold`: The target manifold for the projection
- `prj::ProjectedTo`: The projection configuration

# Examples
```julia
# Example 1: Check dimensionality for a custom logpdf type
struct MyInplaceLogpdf
logpdf!::Function
end

function ExponentialFamilyProjection.check_compatibility(
arg::MyInplaceLogpdf,
manifold::AbstractManifold,
prj::ProjectedTo
)
dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
if !isempty(dims)
# Test with a sample input of the expected dimensions
test_sample = zeros(dims[1])
test_output = [0.0]
try
arg.logpdf!(test_output, test_sample)
catch e
error("Dimensionality mismatch: projection dimensions \$(dims) may be incompatible with logpdf! function")
end
end
end

# Example 2: Validate sample array dimensions
function ExponentialFamilyProjection.check_compatibility(
samples::AbstractArray,
manifold::AbstractManifold,
prj::ProjectedTo
)
dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
if !isempty(dims) && size(samples, 1) != dims[1]
error("Sample dimension \$(size(samples, 1)) does not match projection dimension \$(dims[1])")
end
end
```

!!! note
This function is called before the projection optimization starts. By default,
it performs no checks and returns immediately. Users are encouraged to implement
this method for their custom types when appropriate validation is needed.
"""
function check_compatibility(projection_argument, manifold, prj::ProjectedTo)
# Default implementation: do nothing
return nothing
end
"""
project_to(to::ProjectedTo, argument::F, supplementary..., initialpoint, kwargs...)
Expand Down Expand Up @@ -290,13 +354,8 @@ function project_to(
return copy(getnaturalparameters(supplementary_ef))
end

try
projection_argument.logpdf!([0.0], randn(prj.dims))
catch e
error(
"The supplied projection dimensions `$(prj.dims)` may be invalid for the provided logpdf! function. Check dimensions and logpdf! function.\n",
)
end
# Allow users to implement custom compatibility checks for their projection arguments
check_compatibility(projection_argument, M, prj)

strategy, projection_argument = preprocess_strategy_argument(
getstrategy(projection_parameters),
Expand Down