From 5a805a3a55c55a06bee005d81ec3e687ce7421b6 Mon Sep 17 00:00:00 2001 From: Tim Date: Sat, 8 Nov 2025 05:44:14 +0100 Subject: [PATCH] #67 Add formatter. make format. --- .gitignore | 3 +- Makefile | 16 + docs/make.jl | 2 +- scripts/Project.toml | 5 + scripts/formatter.jl | 26 + src/jacobians.jl | 84 ++- src/manopt/bounded_norm_update_rule.jl | 6 +- src/manopt/projection_objective.jl | 9 +- src/projected_to.jl | 23 +- src/strategies/bonnet/gauss_newton.jl | 52 +- src/strategies/bonnet/naive_grad_hess.jl | 2 +- src/strategies/bonnet/strategy.jl | 81 ++- src/strategies/control_variate.jl | 12 +- src/strategies/default.jl | 2 +- src/strategies/mle.jl | 16 +- test/aqua_tests.jl | 8 +- test/manopt/bounded_norm_update_rule_tests.jl | 2 +- test/manopt_setuptests.jl | 2 +- test/projection/batch_logpdf.jl | 18 +- test/projection/helpers/debug.jl | 4 +- .../projected_to_bernoulli_tests.jl | 18 +- test/projection/projected_to_beta_tests.jl | 2 +- .../projection/projected_to_binomial_tests.jl | 16 +- test/projection/projected_to_chisq_tests.jl | 2 +- .../projected_to_dirichlet_tests.jl | 22 +- .../projected_to_exponential_tests.jl | 2 +- .../projected_to_geometric_tests.jl | 2 +- .../projected_to_inverse_gamma_tests.jl | 4 +- test/projection/projected_to_laplace_tests.jl | 2 +- .../projected_to_lognormal_tests.jl | 2 +- test/projection/projected_to_normal_tests.jl | 34 +- test/projection/projected_to_poisson_tests.jl | 8 +- .../projection/projected_to_rayleigh_tests.jl | 17 +- test/projection/projected_to_setuptests.jl | 33 +- test/projection/projected_to_tests.jl | 84 +-- test/projection/projected_to_weibull_tests.jl | 40 +- test/runtests.jl | 21 +- test/strategies/bonnet_tests.jl | 589 +++++++++++------- test/strategies/control_variate_tests.jl | 20 +- test/strategies/default_tests.jl | 2 +- test/strategies/gauss_newton_tests.jl | 66 +- test/strategies/mle_tests.jl | 15 +- 42 files changed, 860 insertions(+), 514 deletions(-) create mode 100644 scripts/Project.toml create mode 100644 scripts/formatter.jl diff --git a/.gitignore b/.gitignore index d6b3e82..4b60e9f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ alloc-profile.pb.gz profile.pb.gz *.ipynb -.ipynb_checkpoints \ No newline at end of file +.ipynb_checkpoints +*Manifest.toml \ No newline at end of file diff --git a/Makefile b/Makefile index 4c26e87..e547e03 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,13 @@ DOCSRC = docs DOCTARGET = $(DOCSRC)/build +SCRIPTSRC = scripts +FORMATTER = $(SCRIPTSRC)/formatter.jl + JULIA ?= julia JULIAFLAGS ?= --project=. JULIAFLAGSDOCS ?= --project=$(DOCSRC) +JULIAFLAGSSCRIPTS ?= --project=$(SCRIPTSRC) # Colors for terminal output ifdef NO_COLOR @@ -38,7 +42,10 @@ help: @echo '${GREEN}Development commands:${RESET}' @echo ' ${YELLOW}deps${RESET} Install project dependencies' @echo ' ${YELLOW}deps-docs${RESET} Install documentation dependencies' + @echo ' ${YELLOW}deps-scripts${RESET} Install script dependencies' @echo ' ${YELLOW}test${RESET} Run project tests' + @echo ' ${YELLOW}format${RESET} Format Julia code' + @echo ' ${YELLOW}check-format${RESET} Check Julia code formatting (does not modify files)' @echo ' ${YELLOW}clean${RESET} Clean all generated files' @echo '' @echo '${GREEN}Help:${RESET}' @@ -68,8 +75,17 @@ deps: ## Install project dependencies deps-docs: ## Install documentation dependencies $(JULIA) $(JULIAFLAGSDOCS) -e 'using Pkg; Pkg.develop(path="."); Pkg.instantiate()' +deps-scripts: ## Install script dependencies + $(JULIA) $(JULIAFLAGSSCRIPTS) -e 'using Pkg; Pkg.instantiate()' + test: deps ## Run project tests $(JULIA) $(JULIAFLAGS) -e 'using Pkg; Pkg.test(test_args = split("$(test_args)") .|> string)' +format: deps-scripts ## Format Julia code + $(JULIA) $(JULIAFLAGSSCRIPTS) $(FORMATTER) --overwrite + +check-format: deps-scripts ## Check Julia code formatting (does not modify files) + $(JULIA) $(JULIAFLAGSSCRIPTS) $(FORMATTER) + clean: docs-clean ## Clean all generated files rm -rf .julia/compiled diff --git a/docs/make.jl b/docs/make.jl index 4855ce8..03f76f4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,7 +16,7 @@ makedocs(; canonical = "https://reactivebayes.github.io/ExponentialFamilyProjection.jl", edit_link = "main", assets = String[], - repolink="https://github.com/ReactiveBayes/ExponentialFamilyProjection.jl" + repolink = "https://github.com/ReactiveBayes/ExponentialFamilyProjection.jl", ), pages = ["Home" => "index.md"], ) diff --git a/scripts/Project.toml b/scripts/Project.toml new file mode 100644 index 0000000..7dc6125 --- /dev/null +++ b/scripts/Project.toml @@ -0,0 +1,5 @@ +[deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" diff --git a/scripts/formatter.jl b/scripts/formatter.jl new file mode 100644 index 0000000..cf86c5b --- /dev/null +++ b/scripts/formatter.jl @@ -0,0 +1,26 @@ +using JuliaFormatter +using ArgParse + +s = ArgParseSettings() + +@add_arg_table s begin + "--overwrite" + help = "Overwrite the files with the formatted code" + action = :store_true + default = false +end + +args = parse_args(ARGS, s) +overwrite = args["overwrite"] +projectroot = joinpath(@__DIR__, "..") + +passed = format(projectroot; verbose = true, overwrite = overwrite) + +if !passed && !overwrite + @error "JuliaFormatter check has failed. Run `make format` from the main directory and commit your changes to fix code style." + exit(1) +elseif !passed && overwrite + @info "JuliaFormatter has overwritten files according to style guidelines" +elseif passed + @info "Codestyle from JuliaFormatted checks have passed" +end diff --git a/src/jacobians.jl b/src/jacobians.jl index 9618229..e5eb28b 100644 --- a/src/jacobians.jl +++ b/src/jacobians.jl @@ -3,65 +3,121 @@ function jacobian_nat_to_manifold!(::AbstractManifold, X_p, X_nat) return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.NormalMeanVariance}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{ + F, + ExponentialFamily.NormalMeanVariance, + }, + X_p, + X_nat, +) where {F} X_p[1:1] .= X_nat[1] X_p[2:2] .= -X_nat[2] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Gamma}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Gamma}, + X_p, + X_nat, +) where {F} X_p[1:1] .= X_nat[1] X_p[2:2] .= -X_nat[2] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Rayleigh}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Rayleigh}, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Geometric}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Geometric}, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.GammaInverse}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{ + F, + ExponentialFamily.GammaInverse, + }, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] X_p[2:2] .= -X_nat[2] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Exponential}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Exponential}, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Weibull}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Weibull}, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.Laplace}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.Laplace}, + X_p, + X_nat, +) where {F} X_p[1:1] .= -X_nat[1] return X_p end -function jacobian_nat_to_manifold!(::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.LogNormal}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + ::ExponentialFamilyManifolds.NaturalParametersManifold{F,ExponentialFamily.LogNormal}, + X_p, + X_nat, +) where {F} X_p[1:1] .= X_nat[1] X_p[2:2] .= -X_nat[2] return X_p end -function jacobian_nat_to_manifold!(M::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.MvNormalMeanScalePrecision}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + M::ExponentialFamilyManifolds.NaturalParametersManifold{ + F, + ExponentialFamily.MvNormalMeanScalePrecision, + }, + X_p, + X_nat, +) where {F} k = first(ExponentialFamilyManifolds.getdims(M)) X_p[1:k] .= X_nat[1:k] - X_p[k+1:k+1] .= -X_nat[k+1:k+1] + X_p[(k+1):(k+1)] .= -X_nat[(k+1):(k+1)] return X_p end -function jacobian_nat_to_manifold!(M::ExponentialFamilyManifolds.NaturalParametersManifold{F, ExponentialFamily.MvNormalMeanCovariance}, X_p, X_nat) where {F} +function jacobian_nat_to_manifold!( + M::ExponentialFamilyManifolds.NaturalParametersManifold{ + F, + ExponentialFamily.MvNormalMeanCovariance, + }, + X_p, + X_nat, +) where {F} k = first(ExponentialFamilyManifolds.getdims(M)) X_p[1:k] .= X_nat[1:k] - X_p[(k + 1):(k + k^2)] .= -X_nat[(k + 1):(k + k^2)] + X_p[(k+1):(k+k^2)] .= -X_nat[(k+1):(k+k^2)] return X_p -end \ No newline at end of file +end diff --git a/src/manopt/bounded_norm_update_rule.jl b/src/manopt/bounded_norm_update_rule.jl index 90fb385..fa41b85 100644 --- a/src/manopt/bounded_norm_update_rule.jl +++ b/src/manopt/bounded_norm_update_rule.jl @@ -29,7 +29,10 @@ function init_direction_rule(d::BoundedNormUpdateRule, ::Any) return d end -function init_direction_rule(bounded_direction::BoundedNormUpdateRule{L,D}, M) where {L, D <: Manopt.ManifoldDefaultsFactory} +function init_direction_rule( + bounded_direction::BoundedNormUpdateRule{L,D}, + M, +) where {L,D<:Manopt.ManifoldDefaultsFactory} inner_direction = bounded_direction.direction(M) return BoundedNormUpdateRule(bounded_direction.limit, inner_direction) end @@ -54,4 +57,3 @@ function (b::BoundedNormUpdateRule)( end return step, d end - diff --git a/src/manopt/projection_objective.jl b/src/manopt/projection_objective.jl index 91f7f3c..0db69d4 100644 --- a/src/manopt/projection_objective.jl +++ b/src/manopt/projection_objective.jl @@ -31,7 +31,12 @@ get_supplementary_η(obj::ProjectionCostGradientObjective) = obj.supplementary_ get_strategy(obj::ProjectionCostGradientObjective) = obj.strategy get_strategy_state(obj::ProjectionCostGradientObjective) = obj.strategy_state -function call_objective(objective::ProjectionCostGradientObjective, M::AbstractManifold, X, p) +function call_objective( + objective::ProjectionCostGradientObjective, + M::AbstractManifold, + X, + p, +) current_ef = convert(ExponentialFamilyDistribution, M, p) current_η = copyto!(get_current_η(objective), getnaturalparameters(current_ef)) @@ -89,5 +94,3 @@ end function (objective::ProjectionCostGradientObjective)(M::AbstractManifold, X, p) return call_objective(objective, M, X, p) end - - diff --git a/src/projected_to.jl b/src/projected_to.jl index 9a07487..63d0760 100644 --- a/src/projected_to.jl +++ b/src/projected_to.jl @@ -205,7 +205,13 @@ end using Manopt -function check_inputs(prj::ProjectedTo, projection_argument::F, supplementary...; initialpoint = nothing, kwargs...) where {F} +function check_inputs( + prj::ProjectedTo, + projection_argument::F, + supplementary...; + initialpoint = nothing, + kwargs..., +) where {F} if isnothing(initialpoint) return end @@ -214,7 +220,7 @@ function check_inputs(prj::ProjectedTo, projection_argument::F, supplementary... lazy"The initial point must be on the manifold `$(get_projected_to_manifold(prj))`, got `$(typeof(initialpoint))`", ) end -end +end """ project_to(to::ProjectedTo, argument::F, supplementary..., initialpoint, kwargs...) @@ -294,8 +300,15 @@ function project_to( getstrategy(projection_parameters), projection_argument, ) - current_iteration_point = preprocess_initialpoint(initialpoint, strategy, M, projection_parameters) - check_inputs(prj, projection_argument, supplementary...; initialpoint = current_iteration_point, kwargs...) + current_iteration_point = + preprocess_initialpoint(initialpoint, strategy, M, projection_parameters) + check_inputs( + prj, + projection_argument, + supplementary...; + initialpoint = current_iteration_point, + kwargs..., + ) current_ef = convert(ExponentialFamilyDistribution, M, current_iteration_point) state = create_state!( strategy, @@ -424,4 +437,4 @@ end # Otherwise we just copy the initial point, since we use it for the optimization in place function preprocess_initialpoint(_, initialpoint::AbstractArray, strategy, M, parameters) return copy(initialpoint) -end \ No newline at end of file +end diff --git a/src/strategies/bonnet/gauss_newton.jl b/src/strategies/bonnet/gauss_newton.jl index 288fa79..f1da51a 100644 --- a/src/strategies/bonnet/gauss_newton.jl +++ b/src/strategies/bonnet/gauss_newton.jl @@ -23,7 +23,7 @@ preprocess_strategy_argument(::GaussNewton, argument::AbstractArray) = error( lazy"The `GaussNewton` strategy requires the projection argument to be a callable object (e.g. `Function`) or an `InplaceLogpdfGradHess`. Got `$(typeof(argument))` instead.", ) -Base.@kwdef struct GaussNewtonState{S, L, LB, G, H, M} +Base.@kwdef struct GaussNewtonState{S,L,LB,G,H,M} samples::S logpdfs::L logbasemeasures::LB @@ -52,7 +52,12 @@ function create_state!( state = GaussNewtonState( samples = prepare_samples_container(rng, initial_ef, nsamples, supplementary_η), logpdfs = prepare_logpdfs_container(rng, initial_ef, nsamples, supplementary_η), - logbasemeasures = prepare_logbasemeasures_container(rng, initial_ef, nsamples, supplementary_η), + logbasemeasures = prepare_logbasemeasures_container( + rng, + initial_ef, + nsamples, + supplementary_η, + ), grad = zeros(T, xdim), hessian = zeros(T, xdim, xdim), current_mean = zeros(T, xdim), @@ -72,8 +77,17 @@ function _compute_grad_hess_state!(::Any, state, inplace_projection_argument!) grad_hess!(inplace_projection_argument!, state.grad, state.hessian, state.current_mean) end -function _compute_grad_hess_state!(::Type{ExponentialFamily.NormalMeanVariance}, state, inplace_projection_argument!) - grad_hess!(inplace_projection_argument!, state.grad, state.hessian, state.current_mean[1]) +function _compute_grad_hess_state!( + ::Type{ExponentialFamily.NormalMeanVariance}, + state, + inplace_projection_argument!, +) + grad_hess!( + inplace_projection_argument!, + state.grad, + state.hessian, + state.current_mean[1], + ) end function prepare_state!( @@ -94,13 +108,15 @@ function prepare_state!( _, sample_container = ExponentialFamily.check_logpdf(current_ef, get_samples(state)) one_minus_n_of_supplementary = 1 - length(supplementary_η) nonconstantbasemeasure = - ExponentialFamily.isbasemeasureconstant(current_ef) === ExponentialFamily.NonConstantBaseMeasure() + ExponentialFamily.isbasemeasureconstant(current_ef) === + ExponentialFamily.NonConstantBaseMeasure() # Evaluate logpdf (and base measure if needed) for each sample for the cost for (i, sample) in enumerate(sample_container) if nonconstantbasemeasure @inbounds get_logbasemeasures(state)[i] = - one_minus_n_of_supplementary * ExponentialFamily.logbasemeasure(current_ef, sample) + one_minus_n_of_supplementary * + ExponentialFamily.logbasemeasure(current_ef, sample) end logpdf!(inplace_projection_argument!, view(get_logpdfs(state), i:i), sample) end @@ -125,7 +141,8 @@ function compute_cost( gradlogpartition, logpartition, ) - return dot(gradlogpartition, η) - mean(get_logpdfs(state)) - logpartition + mean(get_logbasemeasures(state)) + return dot(gradlogpartition, η) - mean(get_logpdfs(state)) - logpartition + + mean(get_logbasemeasures(state)) end function compute_gradient!( @@ -148,7 +165,7 @@ function call_objective( M::AbstractManifold, X, p, -) where {J,F,C,P,S <: GaussNewton} +) where {J,F,C,P,S<:GaussNewton} current_ef = convert(ExponentialFamilyDistribution, M, p) current_η = copyto!(get_current_η(objective), getnaturalparameters(current_ef)) @@ -175,23 +192,10 @@ function call_objective( map!(-, current_η, current_η, s_η) end - c = compute_cost( - M, - strategy, - state, - current_η, - gradlogpartition, - logpartition, - ) + c = compute_cost(M, strategy, state, current_η, gradlogpartition, logpartition) - X_nat = compute_gradient!( - M, - strategy, - state, - X, - current_η, - ) + X_nat = compute_gradient!(M, strategy, state, X, current_η) X = jacobian_nat_to_manifold!(M, X, X_nat) X = project!(M, X, p, X) return c, X -end \ No newline at end of file +end diff --git a/src/strategies/bonnet/naive_grad_hess.jl b/src/strategies/bonnet/naive_grad_hess.jl index 5657c7a..39b9b96 100644 --- a/src/strategies/bonnet/naive_grad_hess.jl +++ b/src/strategies/bonnet/naive_grad_hess.jl @@ -24,4 +24,4 @@ function (inplace::NaiveGradHess)(out_grad, out_hess, x) inplace.grad!(out_grad, x) inplace.hess!(out_hess, x) return out_grad, out_hess -end \ No newline at end of file +end diff --git a/src/strategies/bonnet/strategy.jl b/src/strategies/bonnet/strategy.jl index 8960515..e89ecea 100644 --- a/src/strategies/bonnet/strategy.jl +++ b/src/strategies/bonnet/strategy.jl @@ -16,7 +16,7 @@ The following parameters are available: This strategy requires a logpdf function that can be converted to an `InplaceLogpdfGradHess` object. This strategy requires the normal manifold. """ -Base.@kwdef struct BonnetStrategy{S, TL} +Base.@kwdef struct BonnetStrategy{S,TL} nsamples::S = 2000 base_logpdf_type::Type{TL} = InplaceLogpdfGradHess end @@ -29,7 +29,7 @@ preprocess_strategy_argument(::BonnetStrategy, argument::AbstractArray) = error( lazy"The `BonnetStrategy` requires the projection argument to be a callable object (e.g. `Function`) or an `InplaceLogpdfGradHess`. Got `$(typeof(argument))` instead.", ) -Base.@kwdef struct BonnetStrategyState{S, L, LB, G, H, M} +Base.@kwdef struct BonnetStrategyState{S,L,LB,G,H,M} samples::S logpdfs::L logbasemeasures::LB @@ -57,11 +57,12 @@ function create_state!( # Create containers for the BonnetStrategy state nsamples = get_nsamples(strategy) rng = getrng(parameters) - + # Prepare containers following the same pattern as ControlVariateStrategy samples = prepare_samples_container(rng, initial_ef, nsamples, supplementary_η) logpdfs = prepare_logpdfs_container(rng, initial_ef, nsamples, supplementary_η) - logbasemeasures = prepare_logbasemeasures_container(rng, initial_ef, nsamples, supplementary_η) + logbasemeasures = + prepare_logbasemeasures_container(rng, initial_ef, nsamples, supplementary_η) grads = prepare_grads_container(rng, initial_ef, nsamples, supplementary_η) hessians = prepare_hessians_container(rng, initial_ef, nsamples, supplementary_η) current_mean = prepare_current_mean_container(rng, initial_ef, supplementary_η) @@ -86,19 +87,17 @@ function create_state!( ) end -prepare_grads_container(rng, distribution, nsamples, supplementary_η) = - zeros( - paramfloattype(distribution), - length(mean(distribution)), # dimension of the sample space - nsamples, - ) -prepare_hessians_container(rng, distribution, nsamples, supplementary_η) = - zeros( - paramfloattype(distribution), - length(mean(distribution)), # dimension of the sample space - length(mean(distribution)), # dimension of the sample space - nsamples, - ) +prepare_grads_container(rng, distribution, nsamples, supplementary_η) = zeros( + paramfloattype(distribution), + length(mean(distribution)), # dimension of the sample space + nsamples, +) +prepare_hessians_container(rng, distribution, nsamples, supplementary_η) = zeros( + paramfloattype(distribution), + length(mean(distribution)), # dimension of the sample space + length(mean(distribution)), # dimension of the sample space + nsamples, +) prepare_current_mean_container(rng, distribution, supplementary_η) = zeros(paramfloattype(distribution), length(mean(distribution))) @@ -118,14 +117,15 @@ function prepare_state!( Random.seed!(getrng(parameters), getseed(parameters)) Random.rand!(getrng(parameters), current_ef, get_samples(state)) - + _, sample_container = ExponentialFamily.check_logpdf(current_ef, get_samples(state)) inplace_projection_argument! = convert(TL, projection_argument) - + one_minus_n_of_supplementary = 1 - length(supplementary_η) nonconstantbasemeasure = - ExponentialFamily.isbasemeasureconstant(current_ef) === ExponentialFamily.NonConstantBaseMeasure() - + ExponentialFamily.isbasemeasureconstant(current_ef) === + ExponentialFamily.NonConstantBaseMeasure() + # Evaluate logpdf, grad, and hess for each sample for (i, sample) in enumerate(sample_container) # if `basemeasure` is constant we assume that @@ -135,19 +135,20 @@ function prepare_state!( one_minus_n_of_supplementary * ExponentialFamily.logbasemeasure(current_ef, sample) end - + logpdf!(inplace_projection_argument!, view(get_logpdfs(state), i:i), sample) grad_hess!( inplace_projection_argument!, view(get_grads(state), :, i), - view(get_hessians(state), :, :, i), + view(get_hessians(state),:,:,i), sample, ) end - + current_nat_param = getnaturalparameters(current_ef) exponential_family_typetag = ExponentialFamily.exponential_family_typetag(current_ef) - η1, η2 = ExponentialFamily.unpack_parameters(exponential_family_typetag, current_nat_param) + η1, η2 = + ExponentialFamily.unpack_parameters(exponential_family_typetag, current_nat_param) state.current_mean .= (-2η2) \ η1 return state end @@ -158,9 +159,10 @@ function compute_cost( state::BonnetStrategyState, η, gradlogpartition, - logpartition + logpartition, ) - return dot(gradlogpartition, η) - mean(get_logpdfs(state)) - logpartition + mean(get_logbasemeasures(state)) + return dot(gradlogpartition, η) - mean(get_logpdfs(state)) - logpartition + + mean(get_logbasemeasures(state)) end function compute_gradient!( @@ -178,7 +180,7 @@ function bonnet_compute_gradient!( ::BonnetStrategy, state::BonnetStrategyState, X, - η + η, ) mean_grad_vector_η_1 = mean(get_grads(state), dims = 2)[:, 1] mean_hess_vector_η_2 = mean(get_hessians(state), dims = 3)[:, :, 1] @@ -194,8 +196,8 @@ function call_objective( objective::ProjectionCostGradientObjective{J,F,C,P,S}, M::AbstractManifold, X, - p -) where {J,F,C,P,S <: BonnetStrategy} + p, +) where {J,F,C,P,S<:BonnetStrategy} current_ef = convert(ExponentialFamilyDistribution, M, p) current_η = copyto!(get_current_η(objective), getnaturalparameters(current_ef)) @@ -224,23 +226,10 @@ function call_objective( map!(-, current_η, current_η, s_η) end - c = compute_cost( - M, - strategy, - state, - current_η, - gradlogpartition, - logpartition - ) + c = compute_cost(M, strategy, state, current_η, gradlogpartition, logpartition) - X_nat = compute_gradient!( - M, - strategy, - state, - X, - current_η, - ) + X_nat = compute_gradient!(M, strategy, state, X, current_η) X = jacobian_nat_to_manifold!(M, X, X_nat) X = project!(M, X, p, X) return c, X -end \ No newline at end of file +end diff --git a/src/strategies/control_variate.jl b/src/strategies/control_variate.jl index a8c0c9c..9cdb7ef 100644 --- a/src/strategies/control_variate.jl +++ b/src/strategies/control_variate.jl @@ -15,10 +15,10 @@ The following parameters are available: !!! note This strategy requires a function as an argument for `project_to` and cannot project a collection of samples. Use `MLEStrategy` to project a collection of samples. """ -Base.@kwdef struct ControlVariateStrategy{S, B, TL} +Base.@kwdef struct ControlVariateStrategy{S,B,TL} nsamples::S = 2000 buffer::B = Bumper.SlabBuffer() - base_logpdf_type::Type{TL} = InplaceLogpdf + base_logpdf_type::Type{TL} = InplaceLogpdf end get_nsamples(strategy::ControlVariateStrategy) = strategy.nsamples @@ -28,8 +28,10 @@ function Base.:(==)(a::ControlVariateStrategy, b::ControlVariateStrategy)::Bool return get_nsamples(a) == get_nsamples(b) && get_buffer(a) == get_buffer(b) end -preprocess_strategy_argument(strategy::ControlVariateStrategy{S,B,TL}, argument::Any) where {S,B,TL} = - (strategy, convert(TL, argument)) +preprocess_strategy_argument( + strategy::ControlVariateStrategy{S,B,TL}, + argument::Any, +) where {S,B,TL} = (strategy, convert(TL, argument)) preprocess_strategy_argument(::ControlVariateStrategy, argument::AbstractArray) = error( lazy"The `ControlVariateStrategy` requires the projection argument to be a callable object (e.g. `Function`). Got `$(typeof(argument))` instead.", ) @@ -356,4 +358,4 @@ function control_variate_cov_buffered!(buffer, Z, X, Y) nothing end return Z -end \ No newline at end of file +end diff --git a/src/strategies/default.jl b/src/strategies/default.jl index 099562f..4d718ff 100644 --- a/src/strategies/default.jl +++ b/src/strategies/default.jl @@ -15,4 +15,4 @@ struct DefaultStrategy end preprocess_strategy_argument(::DefaultStrategy, argument::AbstractArray) = preprocess_strategy_argument(MLEStrategy(), argument) preprocess_strategy_argument(::DefaultStrategy, argument::Any) = - preprocess_strategy_argument(ControlVariateStrategy(), argument) \ No newline at end of file + preprocess_strategy_argument(ControlVariateStrategy(), argument) diff --git a/src/strategies/mle.jl b/src/strategies/mle.jl index 2b5eed3..6d50716 100644 --- a/src/strategies/mle.jl +++ b/src/strategies/mle.jl @@ -77,7 +77,11 @@ end function (fn::MLETargetFn)(η) # This function essentially computes the negative average of `logpdf` of all provided `samples` # with the distribution defined in `η` - ef = convert(ExponentialFamilyDistribution, fn.manifold, ExponentialFamilyManifolds.partition_point(fn.manifold, η)) + ef = convert( + ExponentialFamilyDistribution, + fn.manifold, + ExponentialFamilyManifolds.partition_point(fn.manifold, η), + ) _, samples_container = ExponentialFamily.check_logpdf(ef, fn.samples) # We use precomputed `sufficientstatistics` since in this strategy the `samples` are fixed sufficientstatistics_container = eachcol(fn.sufficientstatistics) @@ -99,7 +103,7 @@ function compute_cost( _, _, _, -) +) return gettargetfn(state)(η) end @@ -113,10 +117,14 @@ function compute_gradient!( _, inv_fisher, ) - ef = convert(ExponentialFamilyDistribution, M, ExponentialFamilyManifolds.partition_point(M, η)) + ef = convert( + ExponentialFamilyDistribution, + M, + ExponentialFamilyManifolds.partition_point(M, η), + ) targetfn = gettargetfn(state) mean_sufficient_stats = mean(targetfn.sufficientstatistics, dims = 2) G = -view(mean_sufficient_stats, :, 1) + gradlogpartition(ef) X = mul!(X, inv_fisher, G) return X -end \ No newline at end of file +end diff --git a/test/aqua_tests.jl b/test/aqua_tests.jl index f1ce272..035bebe 100644 --- a/test/aqua_tests.jl +++ b/test/aqua_tests.jl @@ -1,5 +1,9 @@ @testitem "Aqua: Auto QUality Assurance" begin using Aqua, ExponentialFamilyProjection - Aqua.test_all(ExponentialFamilyProjection; ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true)) -end \ No newline at end of file + Aqua.test_all( + ExponentialFamilyProjection; + ambiguities = false, + deps_compat = (; check_extras = false, check_weakdeps = true), + ) +end diff --git a/test/manopt/bounded_norm_update_rule_tests.jl b/test/manopt/bounded_norm_update_rule_tests.jl index 6cd4c90..d1d6a03 100644 --- a/test/manopt/bounded_norm_update_rule_tests.jl +++ b/test/manopt/bounded_norm_update_rule_tests.jl @@ -96,4 +96,4 @@ end end -end \ No newline at end of file +end diff --git a/test/manopt_setuptests.jl b/test/manopt_setuptests.jl index 4d42663..2d28178 100644 --- a/test/manopt_setuptests.jl +++ b/test/manopt_setuptests.jl @@ -22,4 +22,4 @@ struct ConstantStepsizeNonAllocating{T} <: Stepsize end function (cs::ConstantStepsizeNonAllocating)(args...; kwargs...) return cs.stepsize -end \ No newline at end of file +end diff --git a/test/projection/batch_logpdf.jl b/test/projection/batch_logpdf.jl index 165df9a..2f63483 100644 --- a/test/projection/batch_logpdf.jl +++ b/test/projection/batch_logpdf.jl @@ -2,14 +2,14 @@ using BayesBase struct BatchLogpdf{F,N} batch_logpdf::F - + function BatchLogpdf{N}(batch_logpdf::F) where {F,N} return new{F,N}(batch_logpdf) end end # Constructor with batch size -function BatchLogpdf(batch_logpdf::F; batch_size::Int=100) where F +function BatchLogpdf(batch_logpdf::F; batch_size::Int = 100) where {F} return BatchLogpdf{batch_size}(batch_logpdf) end @@ -17,17 +17,17 @@ end function (b::BatchLogpdf{F,N})(out::AbstractVector, samples) where {F,N} n_samples = length(samples) n_batches = ceil(Int, n_samples / N) - + # Process samples in batches - for i in 1:n_batches + for i = 1:n_batches start_idx = (i-1) * N + 1 end_idx = min(i * N, n_samples) batch_slice = start_idx:end_idx - + # Process current batch - view(out, batch_slice).= b.batch_logpdf(view(samples, batch_slice)) + view(out, batch_slice) .= b.batch_logpdf(view(samples, batch_slice)) end - + return out end @@ -47,10 +47,10 @@ function Base.convert(::Type{BatchLogpdf}, logpdf_fn) return BatchLogpdf{100}(logpdf_fn) # Default batch size end -function Base.convert(::Type{BatchLogpdf{N}}, logpdf_fn) where N +function Base.convert(::Type{BatchLogpdf{N}}, logpdf_fn) where {N} return BatchLogpdf{N}(logpdf_fn) end function Base.convert(::Type{BatchLogpdf}, logpdf_fn::BatchLogpdf{F,N}) where {F,N} return logpdf_fn -end \ No newline at end of file +end diff --git a/test/projection/helpers/debug.jl b/test/projection/helpers/debug.jl index 9108c8c..87bfc89 100644 --- a/test/projection/helpers/debug.jl +++ b/test/projection/helpers/debug.jl @@ -29,7 +29,7 @@ if do_debug lines = split(debug_string, '\n') @test length(lines) == n_iterations + 2 - @test all(map(line -> occursin(r"f\(x\): -?\d+\.\d+", line), lines[2:end-1])) + @test all(map(line -> occursin(r"f\(x\): -?\d+\.\d+", line), lines[2:(end-1)])) else @test debug_string == "" end @@ -52,4 +52,4 @@ test_projection_with_debug(n, true, true) end end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_bernoulli_tests.jl b/test/projection/projected_to_bernoulli_tests.jl index 5b23169..0e5ef3b 100644 --- a/test/projection/projected_to_bernoulli_tests.jl +++ b/test/projection/projected_to_bernoulli_tests.jl @@ -31,11 +31,21 @@ end include("./projected_to_setuptests.jl") @testset let distribution = ProductOf(Bernoulli(0.3), Bernoulli(0.65)) - @test test_projection_convergence(distribution, to = Bernoulli, dims = (), conditioner = nothing) + @test test_projection_convergence( + distribution, + to = Bernoulli, + dims = (), + conditioner = nothing, + ) end @testset let distribution = ProductOf(Bernoulli(0.5), Bernoulli(0.95)) - @test test_projection_convergence(distribution, to = Bernoulli, dims = (), conditioner = nothing) + @test test_projection_convergence( + distribution, + to = Bernoulli, + dims = (), + conditioner = nothing, + ) end end @@ -66,5 +76,5 @@ end @test test_projection_mle(distribution) end - -end \ No newline at end of file + +end diff --git a/test/projection/projected_to_beta_tests.jl b/test/projection/projected_to_beta_tests.jl index 976094e..e851037 100644 --- a/test/projection/projected_to_beta_tests.jl +++ b/test/projection/projected_to_beta_tests.jl @@ -120,4 +120,4 @@ end @testset let distribution = Beta(1000, 1000) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_binomial_tests.jl b/test/projection/projected_to_binomial_tests.jl index c09c2a6..02fc447 100644 --- a/test/projection/projected_to_binomial_tests.jl +++ b/test/projection/projected_to_binomial_tests.jl @@ -12,7 +12,7 @@ @test test_projection_convergence(distribution) end - @testset let distribution = Binomial(2, 0.9) + @testset let distribution = Binomial(2, 0.9) @test test_projection_convergence(distribution) end @@ -35,7 +35,11 @@ end end @testset let distribution = Poisson(5.3) - @test_throws AssertionError test_projection_convergence(distribution, to = Binomial, conditioner = -20) + @test_throws AssertionError test_projection_convergence( + distribution, + to = Binomial, + conditioner = -20, + ) end end @@ -52,7 +56,11 @@ end end @testset let distribution = Bernoulli(0.6) - @test_throws AssertionError test_projection_convergence(distribution, to = Binomial, conditioner = -1) + @test_throws AssertionError test_projection_convergence( + distribution, + to = Binomial, + conditioner = -1, + ) end @testset let distribution = Bernoulli(0.3) @@ -82,7 +90,7 @@ end @test test_projection_mle(distribution) end - @testset let distribution = Binomial(2, 0.9) + @testset let distribution = Binomial(2, 0.9) @test test_projection_mle(distribution) end diff --git a/test/projection/projected_to_chisq_tests.jl b/test/projection/projected_to_chisq_tests.jl index e95905b..3ec9c16 100644 --- a/test/projection/projected_to_chisq_tests.jl +++ b/test/projection/projected_to_chisq_tests.jl @@ -34,4 +34,4 @@ end @testset let distribution = Chisq(10.0) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_dirichlet_tests.jl b/test/projection/projected_to_dirichlet_tests.jl index 345d0c4..71d18e5 100644 --- a/test/projection/projected_to_dirichlet_tests.jl +++ b/test/projection/projected_to_dirichlet_tests.jl @@ -4,23 +4,23 @@ include("./projected_to_setuptests.jl") - @testset let distribution = Dirichlet([ 1.0, 1.0 ]) + @testset let distribution = Dirichlet([1.0, 1.0]) @test test_projection_convergence(distribution) end - @testset let distribution = Dirichlet([ 0.5, 2.0 ]) + @testset let distribution = Dirichlet([0.5, 2.0]) @test test_projection_convergence(distribution) end - @testset let distribution = Dirichlet([ 2.0, 5.0 ]) + @testset let distribution = Dirichlet([2.0, 5.0]) @test test_projection_convergence(distribution) end - @testset let distribution = Dirichlet([ 3.14, 2.71, 6.81 ]) + @testset let distribution = Dirichlet([3.14, 2.71, 6.81]) @test test_projection_convergence(distribution) end - @testset let distribution = Dirichlet([ 2.0, 5.0, 7.0, 0.5 ]) + @testset let distribution = Dirichlet([2.0, 5.0, 7.0, 0.5]) @test test_projection_convergence(distribution) end @@ -32,24 +32,24 @@ end include("./projected_to_setuptests.jl") - @testset let distribution = Dirichlet([ 1.0, 1.0 ]) + @testset let distribution = Dirichlet([1.0, 1.0]) @test test_projection_mle(distribution) end - @testset let distribution = Dirichlet([ 0.5, 2.0 ]) + @testset let distribution = Dirichlet([0.5, 2.0]) @test test_projection_mle(distribution) end - @testset let distribution = Dirichlet([ 2.0, 5.0 ]) + @testset let distribution = Dirichlet([2.0, 5.0]) @test test_projection_mle(distribution) end - @testset let distribution = Dirichlet([ 3.14, 2.71, 6.81 ]) + @testset let distribution = Dirichlet([3.14, 2.71, 6.81]) @test test_projection_mle(distribution) end - @testset let distribution = Dirichlet([ 2.0, 5.0, 7.0, 0.5 ]) + @testset let distribution = Dirichlet([2.0, 5.0, 7.0, 0.5]) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_exponential_tests.jl b/test/projection/projected_to_exponential_tests.jl index ab00371..1823e83 100644 --- a/test/projection/projected_to_exponential_tests.jl +++ b/test/projection/projected_to_exponential_tests.jl @@ -50,4 +50,4 @@ end @testset let distribution = Exponential(7.41) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_geometric_tests.jl b/test/projection/projected_to_geometric_tests.jl index d7d6982..f2470c3 100644 --- a/test/projection/projected_to_geometric_tests.jl +++ b/test/projection/projected_to_geometric_tests.jl @@ -50,4 +50,4 @@ end @testset let distribution = Geometric(0.9) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_inverse_gamma_tests.jl b/test/projection/projected_to_inverse_gamma_tests.jl index f1264d9..ee43447 100644 --- a/test/projection/projected_to_inverse_gamma_tests.jl +++ b/test/projection/projected_to_inverse_gamma_tests.jl @@ -13,7 +13,7 @@ end @testset let distribution = InverseGamma(1, 0.5) - @test test_projection_convergence(distribution, nsamples_range=100:200:2000) + @test test_projection_convergence(distribution, nsamples_range = 100:200:2000) end end @@ -34,4 +34,4 @@ end @testset let distribution = InverseGamma(1, 10) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_laplace_tests.jl b/test/projection/projected_to_laplace_tests.jl index 51dda1e..d29ea7b 100644 --- a/test/projection/projected_to_laplace_tests.jl +++ b/test/projection/projected_to_laplace_tests.jl @@ -35,4 +35,4 @@ end @testset let distribution = Laplace(-3.14, 2.71) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_lognormal_tests.jl b/test/projection/projected_to_lognormal_tests.jl index 8565f93..aba2830 100644 --- a/test/projection/projected_to_lognormal_tests.jl +++ b/test/projection/projected_to_lognormal_tests.jl @@ -34,4 +34,4 @@ end @testset let distribution = LogNormal(-3.14, 2.71) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_normal_tests.jl b/test/projection/projected_to_normal_tests.jl index a4890f4..8403e9f 100644 --- a/test/projection/projected_to_normal_tests.jl +++ b/test/projection/projected_to_normal_tests.jl @@ -122,7 +122,7 @@ end distribution, nsamples_range = 500:100:2000, nsamples_niterations = 1000, - nsamples_required_accuracy = 1e-1 + nsamples_required_accuracy = 1e-1, ) end @@ -131,7 +131,7 @@ end distribution, nsamples_range = 500:100:2000, nsamples_niterations = 1000, - nsamples_required_accuracy = 1e-1 + nsamples_required_accuracy = 1e-1, ) end @@ -140,12 +140,16 @@ end distribution, nsamples_range = 500:100:2000, nsamples_niterations = 1000, - nsamples_required_accuracy = 1e-1 + nsamples_required_accuracy = 1e-1, ) end - @testset let distribution = MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4)) - @test test_bonnet_projection_convergence(distribution, niterations_range = 500:100:2000) + @testset let distribution = + MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4)) + @test test_bonnet_projection_convergence( + distribution, + niterations_range = 500:100:2000, + ) end end end @@ -169,8 +173,12 @@ end @test test_gaussnewton_projection_convergence(distribution) end - @testset let distribution = MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4)) - @test test_gaussnewton_projection_convergence(distribution, niterations_range = 500:100:2000) + @testset let distribution = + MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4)) + @test test_gaussnewton_projection_convergence( + distribution, + niterations_range = 500:100:2000, + ) end end end @@ -247,9 +255,9 @@ end end @testset let distribution = ProductOf( - MvNormalMeanScalePrecision(ones(20), 2), - MvNormalMeanScalePrecision(ones(20), 3), - ) + MvNormalMeanScalePrecision(ones(20), 2), + MvNormalMeanScalePrecision(ones(20), 3), + ) @test test_projection_convergence( distribution, to = MvNormalMeanScalePrecision, @@ -258,8 +266,8 @@ end nsamples_niterations = 6000, nsamples_range = 1000:1000:6000, niterations_range = 400:100:1000, - nsamples_required_accuracy=0.3, - niterations_required_accuracy=0.3 + nsamples_required_accuracy = 0.3, + niterations_required_accuracy = 0.3, ) end @@ -322,4 +330,4 @@ end @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_poisson_tests.jl b/test/projection/projected_to_poisson_tests.jl index 21be8a4..efb40ab 100644 --- a/test/projection/projected_to_poisson_tests.jl +++ b/test/projection/projected_to_poisson_tests.jl @@ -9,7 +9,11 @@ end @testset let distribution = Poisson(10) - @test test_projection_convergence(distribution, nsamples_range = 500:200:4000, niterations_nsamples = 700) + @test test_projection_convergence( + distribution, + nsamples_range = 500:200:4000, + niterations_nsamples = 700, + ) end @testset let distribution = Poisson(0.5) @@ -50,4 +54,4 @@ end @testset let distribution = Poisson(0.5) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_rayleigh_tests.jl b/test/projection/projected_to_rayleigh_tests.jl index c5684f5..1fa6f87 100644 --- a/test/projection/projected_to_rayleigh_tests.jl +++ b/test/projection/projected_to_rayleigh_tests.jl @@ -33,17 +33,11 @@ end include("./projected_to_setuptests.jl") @testset let distribution = Gamma(4, 10) - @test test_projection_convergence( - distribution, - to = Rayleigh - ) + @test test_projection_convergence(distribution, to = Rayleigh) end @testset let distribution = Gamma(40, 10) - @test_broken test_projection_convergence( - distribution, - to = Rayleigh - ) + @test_broken test_projection_convergence(distribution, to = Rayleigh) end end @@ -54,10 +48,7 @@ end include("./projected_to_setuptests.jl") @testset let distribution = LogNormal(0.1, 1) - @test_broken test_projection_convergence( - distribution, - to = Rayleigh - ) + @test_broken test_projection_convergence(distribution, to = Rayleigh) end end @@ -86,4 +77,4 @@ end @testset let distribution = Rayleigh(100.26) @test test_projection_mle(distribution) end -end \ No newline at end of file +end diff --git a/test/projection/projected_to_setuptests.jl b/test/projection/projected_to_setuptests.jl index 48f5224..49ada7e 100644 --- a/test/projection/projected_to_setuptests.jl +++ b/test/projection/projected_to_setuptests.jl @@ -1,4 +1,5 @@ -using ExponentialFamily, Distributions, BayesBase, StableRNGs, RollingFunctions, Manopt, ForwardDiff +using ExponentialFamily, + Distributions, BayesBase, StableRNGs, RollingFunctions, Manopt, ForwardDiff import ExponentialFamilyProjection: InplaceLogpdfGradHess, BonnetStrategy, GaussNewton function test_projection_mle( @@ -368,7 +369,7 @@ end # Helper function to create InplaceLogpdfGradHess for BonnetStrategy testing function create_bonnet_target(distribution) - + if distribution isa NormalMeanVariance # Univariate case logpdf_fn = (out, x) -> (out[1] = logpdf(distribution, x)) @@ -376,7 +377,12 @@ function create_bonnet_target(distribution) (out, x) -> (out[1] = ForwardDiff.derivative(x -> logpdf(distribution, x), x)) end hess_fn = let ForwardDiff = ForwardDiff - (out, x) -> (out[1] = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> logpdf(distribution, x), x), x)) + (out, x) -> ( + out[1] = ForwardDiff.derivative( + x -> ForwardDiff.derivative(x -> logpdf(distribution, x), x), + x, + ) + ) end else # Multivariate case @@ -405,12 +411,14 @@ function test_bonnet_projection_convergence( nsamples_rng = StableRNG(42), kwargs..., ) - T = ismissing(to) ? + T = + ismissing(to) ? ExponentialFamily.exponential_family_typetag( convert(ExponentialFamilyDistribution, distribution), ) : to dims = ismissing(dims) ? size(rand(StableRNG(42), distribution)) : dims - conditioner = ismissing(conditioner) ? + conditioner = + ismissing(conditioner) ? getconditioner(convert(ExponentialFamilyDistribution, distribution)) : conditioner bonnet_target = create_bonnet_target(distribution) @@ -423,7 +431,8 @@ function test_bonnet_projection_convergence( stepsize = nsamples_stepsize, seed = rand(nsamples_rng, UInt), ) - projection = ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) + projection = + ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) approximated = project_to(projection, bonnet_target) divergence = test_convergence_metric(approximated, distribution) return divergence, approximated @@ -457,7 +466,7 @@ function test_bonnet_niterations_convergence( niterations_required_accuracy = 1e-1, niterations_stepsize = ConstantLength(0.1), niterations_rng = StableRNG(42), - kwargs... + kwargs..., ) bonnet_target = create_bonnet_target(distribution) @@ -508,12 +517,14 @@ function test_gaussnewton_projection_convergence( niterations_rng = StableRNG(42), kwargs..., ) - T = ismissing(to) ? + T = + ismissing(to) ? ExponentialFamily.exponential_family_typetag( convert(ExponentialFamilyDistribution, distribution), ) : to dims = ismissing(dims) ? size(rand(StableRNG(42), distribution)) : dims - conditioner = ismissing(conditioner) ? + conditioner = + ismissing(conditioner) ? getconditioner(convert(ExponentialFamilyDistribution, distribution)) : conditioner target = create_bonnet_target(distribution) @@ -526,7 +537,8 @@ function test_gaussnewton_projection_convergence( stepsize = niterations_stepsize, seed = rand(niterations_rng, UInt), ) - projection = ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) + projection = + ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) approximated = project_to(projection, target) divergence = test_convergence_metric(approximated, distribution) return divergence, approximated @@ -550,4 +562,3 @@ function test_gaussnewton_projection_convergence( return test_required_accuracy && test_convergence end - diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl index 2d5d0e6..05b6231 100644 --- a/test/projection/projected_to_tests.jl +++ b/test/projection/projected_to_tests.jl @@ -518,7 +518,7 @@ end nothing, ) initialpoint = rand(manifold) - direction = MomentumGradient(p=initialpoint) + direction = MomentumGradient(p = initialpoint) momentum_parameters = ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8) @@ -576,15 +576,13 @@ end initialpoint = rand(rng, manifold) - direction = ExponentialFamilyProjection.BoundedNormUpdateRule(10.0; - direction = Manopt.MomentumGradient(p=initialpoint) + direction = ExponentialFamilyProjection.BoundedNormUpdateRule( + 10.0; + direction = Manopt.MomentumGradient(p = initialpoint), ) - combined_parameters = ProjectionParameters( - direction = direction, - niterations = 1000, - tolerance = 1e-8 - ) + combined_parameters = + ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8) projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = combined_parameters) approximated = project_to(projection, samples, initialpoint = initialpoint) @@ -608,20 +606,18 @@ end ) initialpoint = rand(manifold) - update_rules = [ - Nesterov(), - MomentumGradient(momentum=0.9), - Manopt.IdentityUpdateRule() - ] + update_rules = + [Nesterov(), MomentumGradient(momentum = 0.9), Manopt.IdentityUpdateRule()] for update_rule in update_rules - direction = ExponentialFamilyProjection.BoundedNormUpdateRule(1000.0; - direction = update_rule + direction = ExponentialFamilyProjection.BoundedNormUpdateRule( + 1000.0; + direction = update_rule, ) - momentum_parameters = - ProjectionParameters(direction = direction, tolerance = 1e-8) + momentum_parameters = ProjectionParameters(direction = direction, tolerance = 1e-8) - projection = ProjectedTo(MvNormalMeanCovariance, 3, parameters = momentum_parameters) + projection = + ProjectedTo(MvNormalMeanCovariance, 3, parameters = momentum_parameters) approximated = project_to(projection, logp, initialpoint = initialpoint) @@ -637,7 +633,7 @@ end logp = (x) -> logpdf(true_dist, x) - for i in 1:10 + for i = 1:10 manifold = ExponentialFamilyManifolds.get_natural_manifold( MvNormalMeanCovariance, (i,), @@ -651,7 +647,11 @@ end logp, initialpoint = initialpoint, ) - @test ExponentialFamilyProjection.check_inputs(projection, logp; initialpoint=nothing) === nothing + @test ExponentialFamilyProjection.check_inputs( + projection, + logp; + initialpoint = nothing, + ) === nothing end end @@ -668,7 +668,7 @@ end include("batch_logpdf.jl") # Create a delayed normal distribution - function create_delayed_normal(delay_seconds=0.1) + function create_delayed_normal(delay_seconds = 0.1) dist = NormalMeanVariance(0.0, 1.0) return function delayed_logpdf(x) sleep(delay_seconds) # expansive operation (for example moving data to GPU) @@ -676,7 +676,7 @@ end end end - delay = 0.0001 + delay = 0.0001 batch_logpdf = BatchLogpdf(create_delayed_normal(delay)) regular_inplace = convert(InplaceLogpdf, create_delayed_normal(delay)); @@ -695,49 +695,49 @@ end bench_regular = @benchmark regular_inplace(out2, samples) seconds=1 bench_batch = @benchmark batch_logpdf(out1, samples) seconds=1 - @test isapprox(mean(bench_converted.times), mean(bench_regular.times), rtol=1e-1) + @test isapprox(mean(bench_converted.times), mean(bench_regular.times), rtol = 1e-1) # This is not a correctness test, but a performance test. # On Julia 1.11, the batch logpdf is faster than on Julia 1.10. @static if VERSION >= v"1.11" @test mean(bench_batch.times) < mean(bench_regular.times)/5 - else + else @test mean(bench_batch.times) < mean(bench_regular.times) end # Create strategies with different base_logpdf_type batch_size = 10 strategy_batch = ExponentialFamilyProjection.ControlVariateStrategy( - nsamples=nsamples, - base_logpdf_type=BatchLogpdf{batch_size} # Ensure we're using the same buffer type + nsamples = nsamples, + base_logpdf_type = BatchLogpdf{batch_size}, # Ensure we're using the same buffer type ) strategy_inplace = ExponentialFamilyProjection.ControlVariateStrategy( - nsamples=nsamples, - base_logpdf_type=InplaceLogpdf + nsamples = nsamples, + base_logpdf_type = InplaceLogpdf, ) projection_batch = ProjectedTo( - NormalMeanVariance, - parameters=ProjectionParameters( - niterations=3, - tolerance=1e-1, - strategy=strategy_batch - ) + NormalMeanVariance, + parameters = ProjectionParameters( + niterations = 3, + tolerance = 1e-1, + strategy = strategy_batch, + ), ) projection_inplace = ProjectedTo( - NormalMeanVariance, - parameters=ProjectionParameters( - niterations=3, - tolerance=1e-1, - strategy=strategy_inplace - ) + NormalMeanVariance, + parameters = ProjectionParameters( + niterations = 3, + tolerance = 1e-1, + strategy = strategy_inplace, + ), ) # Add counter to track number of logpdf calls ncalls = 0 - target_logpdf = function(x) + target_logpdf = function (x) global ncalls += 1 sleep(delay) return logpdf(NormalMeanVariance(0.0, 1.0), x) @@ -753,4 +753,4 @@ end ncalls_inplace = ncalls @test ncalls_batch == ncalls_inplace // batch_size @test result_batch ≈ result_inplace -end \ No newline at end of file +end diff --git a/test/projection/projected_to_weibull_tests.jl b/test/projection/projected_to_weibull_tests.jl index 4fc8e65..70ed149 100644 --- a/test/projection/projected_to_weibull_tests.jl +++ b/test/projection/projected_to_weibull_tests.jl @@ -25,35 +25,23 @@ end include("./projected_to_setuptests.jl") @testset let distribution = Exponential(0.1) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 1.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 1.0) end @testset let distribution = Exponential(0.1) @test_throws AssertionError test_projection_convergence( distribution, to = Weibull, - conditioner = -1.0 + conditioner = -1.0, ) end @testset let distribution = Exponential(10.1) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 1.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 1.0) end @testset let distribution = Exponential(100.32) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 1.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 1.0) end end @@ -64,35 +52,23 @@ end include("./projected_to_setuptests.jl") @testset let distribution = Rayleigh(0.1) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 2.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 2.0) end @testset let distribution = Rayleigh(0.1) @test_throws AssertionError test_projection_convergence( distribution, to = Weibull, - conditioner = -2.0 + conditioner = -2.0, ) end @testset let distribution = Rayleigh(10.1) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 2.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 2.0) end @testset let distribution = Rayleigh(100.32) - @test test_projection_convergence( - distribution, - to = Weibull, - conditioner = 2.0 - ) + @test test_projection_convergence(distribution, to = Weibull, conditioner = 2.0) end end diff --git a/test/runtests.jl b/test/runtests.jl index e8e2600..ee5bd0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,12 @@ using Aqua, Hwloc, ReTestItems, ExponentialFamilyProjection, Random Random.seed!(42) if get(ENV, "RUN_AQUA", "true") == "true" - Aqua.test_all(ExponentialFamilyProjection; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true)) + Aqua.test_all( + ExponentialFamilyProjection; + ambiguities = false, + piracies = false, + deps_compat = (; check_extras = false, check_weakdeps = true), + ) end nthreads, ncores = Hwloc.num_virtual_cores(), Hwloc.num_physical_cores() @@ -16,7 +21,12 @@ pkg_root = dirname(pathof(ExponentialFamilyProjection)) |> dirname test_root = joinpath(pkg_root, "test") if isempty(ARGS) - runtests(ExponentialFamilyProjection; nworkers = ncores, nworker_threads = nworker_threads, memory_threshold = memory_threshold) + runtests( + ExponentialFamilyProjection; + nworkers = ncores, + nworker_threads = nworker_threads, + memory_threshold = memory_threshold, + ) else for arg in ARGS # Translate colon syntax (e.g., rules:normal_mean_variance → rules/normal_mean_variance) @@ -30,7 +40,12 @@ else if path !== nothing selected_path = paths[path] @info "Running selective tests from $selected_path" - runtests(selected_path; nworkers = ncores, nworker_threads = nworker_threads, memory_threshold = memory_threshold) + runtests( + selected_path; + nworkers = ncores, + nworker_threads = nworker_threads, + memory_threshold = memory_threshold, + ) else @warn "Test target not found: $arg" end diff --git a/test/strategies/bonnet_tests.jl b/test/strategies/bonnet_tests.jl index 27cc74b..c2b96b2 100644 --- a/test/strategies/bonnet_tests.jl +++ b/test/strategies/bonnet_tests.jl @@ -3,13 +3,13 @@ using LinearAlgebra # Define simple functions for testing - logpdf_fn! = (out, x) -> out .= -(x .- 1).^2 + logpdf_fn! = (out, x) -> out .= -(x .- 1) .^ 2 grad_fn! = (out, x) -> out .= -2 .* (x .- 1) hess_fn! = (out, x) -> out .= -2 .* I # Test construction inplace_struct = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + @test inplace_struct isa InplaceLogpdfGradHess @test inplace_struct.logpdf! === logpdf_fn! @test inplace_struct.grad_hess! isa ExponentialFamilyProjection.NaiveGradHess @@ -53,18 +53,18 @@ end @testitem "InplaceLogpdfGradHess univariate case" begin import ExponentialFamilyProjection: InplaceLogpdfGradHess - + # Univariate quadratic: -(x-1)² # logpdf: -(x-1)² # grad: -2(x-1) # hess: -2 - + logpdf_fn! = (out, x) -> (out[1] = -(x - 1)^2) grad_fn! = (out, x) -> (out[1] = -2 * (x - 1)) hess_fn! = (out, x) -> (out[1] = -2) - + inplace_struct = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + # Test points test_points = [0.0, 1.0, 2.0, 3.0] @@ -77,7 +77,7 @@ end grad_out, hess_out = zeros(1), zeros(1) grad_out, hess_out = inplace_struct.grad_hess!(grad_out, hess_out, x) - + expected_grad = -2 * (x - 1) expected_hess = -2 @test grad_out[1] ≈ expected_grad @@ -88,12 +88,12 @@ end @testitem "InplaceLogpdfGradHess multivariate case" begin import ExponentialFamilyProjection: InplaceLogpdfGradHess using LinearAlgebra - + # Multivariate quadratic: -(x₁-1)² - (x₂-2)² # logpdf: -(x₁-1)² - (x₂-2)² # grad: [-2(x₁-1), -2(x₂-2)] # hess: [[-2, 0], [0, -2]] - + logpdf_fn! = (out, x) -> (out[1] = -(x[1] - 1)^2 - (x[2] - 2)^2) grad_fn! = (out, x) -> begin out[1] = -2 * (x[1] - 1) @@ -105,24 +105,24 @@ end out[2, 1] = 0 out[2, 2] = -2 end - + inplace_struct = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + # Test points test_points = [ [0.0, 0.0], [1.0, 2.0], # optimal point [2.0, 3.0], - [0.5, 1.5] + [0.5, 1.5], ] - + for x in test_points # Test logpdf logpdf_out = zeros(1) ExponentialFamilyProjection.logpdf!(inplace_struct, logpdf_out, x) expected_logpdf = -(x[1] - 1)^2 - (x[2] - 2)^2 @test logpdf_out[1] ≈ expected_logpdf - + # Test gradient and hessian together grad_out, hess_out = inplace_struct.grad_hess!(zeros(2), zeros(2, 2), x) expected_grad = [-2 * (x[1] - 1), -2 * (x[2] - 2)] @@ -135,46 +135,46 @@ end @testitem "InplaceLogpdfGradHess higher dimensional case" begin import ExponentialFamilyProjection: InplaceLogpdfGradHess using LinearAlgebra - + # 3D case: -(x₁-1)² - (x₂-2)² - (x₃-3)² dim = 3 targets = [1.0, 2.0, 3.0] - + logpdf_fn! = (out, x) -> begin - out[1] = -sum((x[i] - targets[i])^2 for i in 1:dim) + out[1] = -sum((x[i] - targets[i])^2 for i = 1:dim) end grad_fn! = (out, x) -> begin - for i in 1:dim + for i = 1:dim out[i] = -2 * (x[i] - targets[i]) end end hess_fn! = (out, x) -> begin fill!(out, 0) - for i in 1:dim + for i = 1:dim out[i, i] = -2 end end - + inplace_struct = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + # Test points test_points = [ [0.0, 0.0, 0.0], [1.0, 2.0, 3.0], # optimal point [2.0, 3.0, 4.0], - [0.5, 1.5, 2.5] + [0.5, 1.5, 2.5], ] - + for x in test_points # Test logpdf logpdf_out = zeros(1) ExponentialFamilyProjection.logpdf!(inplace_struct, logpdf_out, x) - expected_logpdf = -sum((x[i] - targets[i])^2 for i in 1:dim) + expected_logpdf = -sum((x[i] - targets[i])^2 for i = 1:dim) @test logpdf_out[1] ≈ expected_logpdf - + # Test gradient and hessian together grad_out, hess_out = inplace_struct.grad_hess!(zeros(dim), zeros(dim, dim), x) - expected_grad = [-2 * (x[i] - targets[i]) for i in 1:dim] + expected_grad = [-2 * (x[i] - targets[i]) for i = 1:dim] expected_hess = -2 * I(dim) @test grad_out ≈ expected_grad @test hess_out ≈ expected_hess @@ -183,21 +183,21 @@ end @testitem "InplaceLogpdfGradHess edge cases and validation" begin import ExponentialFamilyProjection: InplaceLogpdfGradHess - + # Test with different container sizes logpdf_fn! = (out, x) -> (out[1] = -(x[1] - 1)^2) grad_fn! = (out, x) -> (out[1] = -2 * (x[1] - 1)) hess_fn! = (out, x) -> (out[1, 1] = -2) - + inplace_struct = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + x = [2.0] - + # Test that functions modify the containers correctly logpdf_out = ones(1) # start with non-zero values ExponentialFamilyProjection.logpdf!(inplace_struct, logpdf_out, x) @test logpdf_out[1] ≈ -1.0 # -(2-1)² = -1 - + grad_out = ones(1) hess_out = ones(1, 1) grad_out, hess_out = inplace_struct.grad_hess!(grad_out, hess_out, x) @@ -235,10 +235,10 @@ end μ = [1.0, 2.0] Σ = [2.0 0.5; 0.5 1.0] dist = MvNormalMeanCovariance(μ, Σ) - + # Create specific target function: -(x₁-1)² - (x₂-2)² target_logpdf = (x) -> -(x[1] - 1)^2 - (x[2] - 2)^2 - + # Create InplaceLogpdfGradHess for BonnetStrategy logpdf_fn! = (out, x) -> (out[1] = -(x[1] - 1)^2 - (x[2] - 2)^2) grad_fn! = (out, x) -> begin @@ -252,68 +252,70 @@ end out[2, 2] = -2 end bonnet_target = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + # Test parameters nsamples = 80000 seed = 42 sample_dim = 2 - + # Create exponential family distribution and manifold ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + # Use the same initial point for both strategies rng = StableRNG(seed) initial_point = rand(rng, M) initial_ef = convert(ExponentialFamilyDistribution, M, initial_point) - + # Create strategies with same nsamples bonnet_strategy = BonnetStrategy(nsamples = nsamples) control_variate_strategy = ControlVariateStrategy(nsamples = nsamples, buffer = nothing) - + # Test with the same seed for reproducibility bonnet_parameters = ProjectionParameters(rng = StableRNG(seed), seed = seed) cv_parameters = ProjectionParameters(rng = StableRNG(seed), seed = seed) - + # Preprocess the strategy arguments to handle conversion properly import ExponentialFamilyProjection: preprocess_strategy_argument - bonnet_strategy_processed, bonnet_target_processed = preprocess_strategy_argument(bonnet_strategy, bonnet_target) - cv_strategy_processed, cv_target_processed = preprocess_strategy_argument(control_variate_strategy, target_logpdf) - + bonnet_strategy_processed, bonnet_target_processed = + preprocess_strategy_argument(bonnet_strategy, bonnet_target) + cv_strategy_processed, cv_target_processed = + preprocess_strategy_argument(control_variate_strategy, target_logpdf) + # Create containers for BonnetStrategy bonnet_samples = rand(initial_ef, nsamples) bonnet_logpdfs = zeros(nsamples) bonnet_grads = zeros(sample_dim, nsamples) bonnet_hessians = zeros(sample_dim, sample_dim, nsamples) bonnet_current_mean = zeros(sample_dim) - + bonnet_state = BonnetStrategyState( samples = bonnet_samples, logpdfs = bonnet_logpdfs, logbasemeasures = zeros(nsamples), grads = bonnet_grads, hessians = bonnet_hessians, - current_mean = bonnet_current_mean + current_mean = bonnet_current_mean, ) - + # Create containers for ControlVariateStrategy cv_samples = rand(initial_ef, nsamples) cv_logpdfs = zeros(nsamples) cv_logbasemeasures = zeros(nsamples) # We'll handle base measures cv_sufficientstatistics = zeros(length(getnaturalparameters(initial_ef)), nsamples) cv_gradsamples = zeros(length(getnaturalparameters(initial_ef)), nsamples) - + cv_state = ControlVariateStrategyState( samples = cv_samples, logpdfs = cv_logpdfs, logbasemeasures = cv_logbasemeasures, sufficientstatistics = cv_sufficientstatistics, - gradsamples = cv_gradsamples + gradsamples = cv_gradsamples, ) - + # Prepare states supplementary_η = () bonnet_state_prepared = prepare_state!( @@ -323,9 +325,9 @@ end bonnet_parameters, bonnet_target_processed, initial_ef, - supplementary_η + supplementary_η, ) - + cv_state_prepared = prepare_state!( cv_strategy_processed, cv_state, @@ -333,32 +335,32 @@ end cv_parameters, cv_target_processed, initial_ef, - supplementary_η + supplementary_η, ) - + # Verify both strategies are using the same samples (they should be with same seed) @test get_samples(bonnet_state_prepared) ≈ get_samples(cv_state_prepared) - + # Get some parameters from the initial distribution η = getnaturalparameters(initial_ef) logpartition = ExponentialFamily.logpartition(initial_ef) gradlogpartition = ExponentialFamily.gradlogpartition(initial_ef) fisherinformation = ExponentialFamily.fisherinformation(initial_ef) inv_fisher = inv(fisherinformation) - + # Create gradient containers bonnet_gradient = zeros(length(η)) cv_gradient = zeros(length(η)) - + # Compute gradients using both strategies compute_gradient!( M, bonnet_strategy_processed, bonnet_state_prepared, bonnet_gradient, - η + η, ) - + compute_gradient!( M, cv_strategy_processed, @@ -367,26 +369,27 @@ end η, logpartition, gradlogpartition, - inv_fisher + inv_fisher, ) - + # Compare the gradients - they should be approximately equal grad_diff = bonnet_gradient - cv_gradient @test dot(grad_diff, fisherinformation, grad_diff) ≈ 0 atol=1e-3 - + # Additional verification: test that both strategies produce finite results @test all(isfinite, bonnet_gradient) @test all(isfinite, cv_gradient) - + # Test with different initial points to ensure consistency for test_seed in [123, 456, 789] test_rng = StableRNG(test_seed) test_point = rand(test_rng, M) test_ef = convert(ExponentialFamilyDistribution, M, test_point) - - bonnet_params_test = ProjectionParameters(rng = StableRNG(test_seed), seed = test_seed) + + bonnet_params_test = + ProjectionParameters(rng = StableRNG(test_seed), seed = test_seed) cv_params_test = ProjectionParameters(rng = StableRNG(test_seed), seed = test_seed) - + # Prepare states with new test point prepare_state!( bonnet_strategy_processed, @@ -395,9 +398,9 @@ end bonnet_params_test, bonnet_target_processed, test_ef, - supplementary_η + supplementary_η, ) - + prepare_state!( cv_strategy_processed, cv_state, @@ -405,19 +408,19 @@ end cv_params_test, cv_target_processed, test_ef, - supplementary_η + supplementary_η, ) - + # Get parameters for test point test_η = getnaturalparameters(test_ef) test_logpartition = ExponentialFamily.logpartition(test_ef) test_gradlogpartition = ExponentialFamily.gradlogpartition(test_ef) test_inv_fisher = inv(ExponentialFamily.fisherinformation(test_ef)) - + # Reset gradient containers fill!(bonnet_gradient, 0.0) fill!(cv_gradient, 0.0) - + # Compute gradients compute_gradient!( M, @@ -426,7 +429,7 @@ end bonnet_gradient, test_η, ) - + compute_gradient!( M, cv_strategy_processed, @@ -435,9 +438,9 @@ end test_η, test_logpartition, test_gradlogpartition, - test_inv_fisher + test_inv_fisher, ) - + # Compare gradients for this test point grad_diff = bonnet_gradient - cv_gradient @test dot(grad_diff, fisherinformation, grad_diff) ≈ 0 atol=1e-3 @@ -467,23 +470,23 @@ end # Test BonnetStrategyState getters nsamples = 100 sample_dim = 3 - + samples = randn(sample_dim, nsamples) logpdfs = zeros(nsamples) logbasemeasures = zeros(nsamples) grads = randn(sample_dim, nsamples) hessians = randn(sample_dim, sample_dim, nsamples) current_mean = randn(sample_dim) - + state = BonnetStrategyState( samples = samples, logpdfs = logpdfs, logbasemeasures = logbasemeasures, grads = grads, hessians = hessians, - current_mean = current_mean + current_mean = current_mean, ) - + @test get_samples(state) === samples @test get_logpdfs(state) === logpdfs @test get_logbasemeasures(state) === logbasemeasures @@ -517,7 +520,7 @@ end μ = [1.0, 2.0] Σ = [2.0 0.5; 0.5 1.0] dist = MvNormalMeanCovariance(μ, Σ) - + # Create InplaceLogpdfGradHess manually logpdf_fn! = (out, x) -> (out[1] = -(x[1] - 1)^2 - (x[2] - 2)^2) grad_fn! = (out, x) -> begin @@ -531,10 +534,10 @@ end out[2, 2] = -2 end inplace_target = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + nsamples = 50 sample_dim = 2 - + # Pre-create containers samples = rand(dist, nsamples) # This creates (2, nsamples) matrix logpdfs = zeros(nsamples) @@ -542,30 +545,38 @@ end grads = zeros(sample_dim, nsamples) # gradient at each sample hessians = zeros(sample_dim, sample_dim, nsamples) # hessian at each sample current_mean = zeros(sample_dim) - + state = BonnetStrategyState( samples = samples, logpdfs = logpdfs, logbasemeasures = logbasemeasures, grads = grads, hessians = hessians, - current_mean = current_mean + current_mean = current_mean, ) - + # Create exponential family distribution and parameters ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + strategy = BonnetStrategy(nsamples = nsamples) rng = StableRNG(42) parameters = ProjectionParameters(rng = rng) - + # Test prepare_state! (ignoring the current_mean computation for now due to dim_size issue) - prepare_state!(strategy, state, M, parameters, inplace_target, ef, getnaturalparameters(ef)) - + prepare_state!( + strategy, + state, + M, + parameters, + inplace_target, + ef, + getnaturalparameters(ef), + ) + # Test that the containers are filled correctly @test all(isfinite, get_logpdfs(state)) @test all(isfinite, get_grads(state)) @@ -579,7 +590,7 @@ end logpdf_out = zeros(1) ExponentialFamilyProjection.logpdf!(inplace_target, logpdf_out, sample) @test logpdf_out[1] ≈ get_logpdfs(state)[i] - + # Test gradient and hessian evaluation together grad_out, hess_out = inplace_target.grad_hess!(zeros(2), zeros(2, 2), sample) @test grad_out ≈ get_grads(state)[:, i] @@ -611,16 +622,16 @@ end # Test with univariate normal distribution dist = NormalMeanVariance(1.0, 2.0) - + # Create InplaceLogpdfGradHess manually logpdf_fn! = (out, x) -> (out[1] = -(x - 1)^2) grad_fn! = (out, x) -> (out[1] = -2 * (x - 1)) hess_fn! = (out, x) -> (out[1] = -2) inplace_target = InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) - + nsamples = 30 sample_dim = 1 - + # Pre-create containers for univariate case samples = rand(dist, nsamples) # This creates a vector of length nsamples logpdfs = zeros(nsamples) @@ -628,49 +639,57 @@ end grads = zeros(sample_dim, nsamples) # (1, nsamples) matrix hessians = zeros(sample_dim, sample_dim, nsamples) # (1, 1, nsamples) array current_mean = zeros(sample_dim) - + state = BonnetStrategyState( samples = samples, logpdfs = logpdfs, logbasemeasures = logbasemeasures, grads = grads, hessians = hessians, - current_mean = current_mean + current_mean = current_mean, ) - + # Create exponential family distribution and parameters ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + strategy = BonnetStrategy(nsamples = nsamples) rng = StableRNG(42) parameters = ProjectionParameters(rng = rng) - prepare_state!(strategy, state, M, parameters, inplace_target, ef, getnaturalparameters(ef)) - + prepare_state!( + strategy, + state, + M, + parameters, + inplace_target, + ef, + getnaturalparameters(ef), + ) + # Verify containers are filled correctly @test all(isfinite, get_logpdfs(state)) @test all(isfinite, get_grads(state)) @test all(isfinite, get_hessians(state)) - + _, sample_container = ExponentialFamily.check_logpdf(ef, get_samples(state)) - + # Manually evaluate for each sample for (i, sample) in enumerate(sample_container) # Test logpdf evaluation logpdf_out = zeros(1) ExponentialFamilyProjection.logpdf!(inplace_target, logpdf_out, sample) @test logpdf_out[1] ≈ get_logpdfs(state)[i] - + # Test gradient and hessian evaluation together grad_out, hess_out = inplace_target.grad_hess!(zeros(1), zeros(1, 1), sample) @test grad_out[1] ≈ get_grads(state)[1, i] @test hess_out[1, 1] ≈ get_hessians(state)[1, 1, i] end - + end @testitem "`BonnetStrategy` should fail if given a list of samples instead of a function" begin @@ -702,10 +721,10 @@ end rng = StableRNG(42) for distribution in [ - NormalMeanVariance(0, 1), - NormalMeanVariance(2, 1), - MvNormalMeanCovariance(ones(2), Matrix(Diagonal(ones(2)))), - ] + NormalMeanVariance(0, 1), + NormalMeanVariance(2, 1), + MvNormalMeanCovariance(ones(2), Matrix(Diagonal(ones(2)))), + ] # Create target distribution to project to target_dist = distribution @@ -715,32 +734,53 @@ end if distribution isa NormalMeanVariance # Univariate case logpdf_fn! = (out, x) -> (out[1] = logpdf(target_dist, x)) - grad_fn! = (out, x) -> (out[1] = ForwardDiff.derivative(x -> logpdf(target_dist, x), x)) - hess_fn! = (out, x) -> (out[1] = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> logpdf(target_dist, x), x), x)) + grad_fn! = + (out, x) -> + (out[1] = ForwardDiff.derivative(x -> logpdf(target_dist, x), x)) + hess_fn! = + (out, x) -> ( + out[1] = ForwardDiff.derivative( + x -> ForwardDiff.derivative(x -> logpdf(target_dist, x), x), + x, + ) + ) else # Multivariate case logpdf_fn! = (out, x) -> (out[1] = logpdf(target_dist, x)) - grad_fn! = (out, x) -> (out .= ForwardDiff.gradient(x -> logpdf(target_dist, x), x)) - hess_fn! = (out, x) -> (out .= ForwardDiff.hessian(x -> logpdf(target_dist, x), x)) + grad_fn! = + (out, x) -> (out .= ForwardDiff.gradient(x -> logpdf(target_dist, x), x)) + hess_fn! = + (out, x) -> (out .= ForwardDiff.hessian(x -> logpdf(target_dist, x), x)) end - - bonnet_target = ExponentialFamilyProjection.InplaceLogpdfGradHess(logpdf_fn!, grad_fn!, hess_fn!) + + bonnet_target = ExponentialFamilyProjection.InplaceLogpdfGradHess( + logpdf_fn!, + grad_fn!, + hess_fn!, + ) ef = convert(ExponentialFamilyDistribution, distribution) T = ExponentialFamily.exponential_family_typetag(ef) c = getconditioner(ef) d = size(rand(rng, ef)) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + # Test with both BonnetStrategy and ControlVariateStrategy for comparison bonnet_strategy = ExponentialFamilyProjection.BonnetStrategy(nsamples = 10000) cv_strategy = ExponentialFamilyProjection.ControlVariateStrategy(nsamples = 10000) - + p = ProjectionParameters(rng = StableRNG(42), seed = 42) η = getnaturalparameters(ef) # Test BonnetStrategy - bonnet_state = ExponentialFamilyProjection.create_state!(bonnet_strategy, M, p, bonnet_target, ef, ()) + bonnet_state = ExponentialFamilyProjection.create_state!( + bonnet_strategy, + M, + p, + bonnet_target, + ef, + (), + ) bonnet_obj = ExponentialFamilyProjection.ProjectionCostGradientObjective( p, bonnet_target, @@ -751,7 +791,8 @@ end ) # Test ControlVariateStrategy for comparison - cv_state = ExponentialFamilyProjection.create_state!(cv_strategy, M, p, targetfn, ef, ()) + cv_state = + ExponentialFamilyProjection.create_state!(cv_strategy, M, p, targetfn, ef, ()) cv_obj = ExponentialFamilyProjection.ProjectionCostGradientObjective( p, targetfn, @@ -768,7 +809,7 @@ end _gradlogpartition = gradlogpartition(ef) _fisher = fisherinformation(ef) _inv_fisher = inv(fisherinformation(ef)) - + # Compute costs bonnet_cost = ExponentialFamilyProjection.compute_cost( M, @@ -776,7 +817,7 @@ end bonnet_state, η, _gradlogpartition, - _logpartition + _logpartition, ) cv_cost = ExponentialFamilyProjection.compute_cost( @@ -798,7 +839,7 @@ end bonnet_strategy, bonnet_state, bonnet_gradient, - η + η, ) ExponentialFamilyProjection.compute_gradient!( @@ -821,13 +862,13 @@ end # Test gradient computation in manifold coordinates p_manifold = ExponentialFamilyManifolds.partition_point(M, η) - + X_p_bonnet = Manifolds.zero_vector(M, p_manifold) X_p_cv = Manifolds.zero_vector(M, p_manifold) - + c_p_bonnet, X_p_bonnet = bonnet_obj(M, X_p_bonnet, p_manifold) c_p_cv, X_p_cv = cv_obj(M, X_p_cv, p_manifold) - + @test c_p_bonnet ≈ bonnet_cost @test c_p_cv ≈ cv_cost X_diff = X_p_bonnet - X_p_cv @@ -869,7 +910,7 @@ end @testset "create_state! basic functionality - NormalMeanVariance" begin dist = NormalMeanVariance(0, 1) supplementary_η = () - + rng = StableRNG(42) ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) @@ -877,28 +918,33 @@ end c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) parameters = ProjectionParameters(; rng = rng) - + # Create logpdf function and its derivatives with proper ForwardDiff scoping logpdf_fn = (out, x) -> (out[1] = logpdf(ef, x)) - + # Univariate case - use derivative instead of gradient grad_fn = let ForwardDiff = ForwardDiff (out, x) -> (out[1] = ForwardDiff.derivative(x -> logpdf(ef, x), x)) end - hess_fn = let ForwardDiff = ForwardDiff - (out, x) -> (out[1] = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> logpdf(ef, x), x), x)) + hess_fn = let ForwardDiff = ForwardDiff + (out, x) -> ( + out[1] = ForwardDiff.derivative( + x -> ForwardDiff.derivative(x -> logpdf(ef, x), x), + x, + ) + ) end - + targetfn = InplaceLogpdfGradHess(logpdf_fn, grad_fn, hess_fn) strategy = BonnetStrategy(nsamples = 100) - + state1 = create_state!(strategy, M, parameters, targetfn, ef, supplementary_η) # Test manual state creation and preparation nsamples = get_nsamples(strategy) sample_dim = length(mean(ef)) param_dim = length(getnaturalparameters(ef)) - + # For univariate: samples is a vector, grads/hessians are in parameter space samples = zeros(paramfloattype(ef), nsamples) # Vector for univariate logpdfs = zeros(paramfloattype(ef), nsamples) @@ -906,7 +952,7 @@ end grads = zeros(paramfloattype(ef), sample_dim, nsamples) # 1 x nsamples hessians = zeros(paramfloattype(ef), sample_dim, sample_dim, nsamples) # 1 x 1 x nsamples current_mean = zeros(paramfloattype(ef), sample_dim) - + state2 = BonnetStrategyState( samples = samples, logpdfs = logpdfs, @@ -915,17 +961,10 @@ end hessians = hessians, current_mean = current_mean, ) - + strategy = BonnetStrategy(nsamples = nsamples) - state2_prepared = prepare_state!( - strategy, - state2, - M, - parameters, - targetfn, - ef, - supplementary_η, - ) + state2_prepared = + prepare_state!(strategy, state2, M, parameters, targetfn, ef, supplementary_η) @test get_samples(state1) == get_samples(state2) @test get_logpdfs(state1) == get_logpdfs(state2) @@ -944,23 +983,28 @@ end c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) parameters = ProjectionParameters(; rng = rng) - + # Create logpdf function and its derivatives with proper ForwardDiff scoping logpdf_fn = (out, x) -> (out[1] = logpdf(ef, x)) - + # Univariate case - use derivative instead of gradient grad_fn = let ForwardDiff = ForwardDiff (out, x) -> (out[1] = ForwardDiff.derivative(x -> logpdf(ef, x), x)) end hess_fn = let ForwardDiff = ForwardDiff - (out, x) -> (out[1] = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> logpdf(ef, x), x), x)) + (out, x) -> ( + out[1] = ForwardDiff.derivative( + x -> ForwardDiff.derivative(x -> logpdf(ef, x), x), + x, + ) + ) end - + targetfn = InplaceLogpdfGradHess(logpdf_fn, grad_fn, hess_fn) strategy = BonnetStrategy(nsamples = 10) - + state = create_state!(strategy, M, parameters, targetfn, ef, ()) - + @test length(get_samples(state)) == 10 # Vector for univariate @test length(get_logpdfs(state)) == 10 @test size(get_grads(state)) == (1, 10) # 1 x nsamples for univariate @@ -998,7 +1042,7 @@ end @testset "create_state! basic functionality - MvNormalMeanCovariance" begin dist = MvNormalMeanCovariance(ones(2), Matrix(Diagonal(ones(2)))) supplementary_η = () - + rng = StableRNG(42) ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) @@ -1006,36 +1050,37 @@ end c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) parameters = ProjectionParameters(; rng = rng) - + # Create logpdf function and its derivatives with proper ForwardDiff scoping logpdf_fn = (out, x) -> (out[1] = logpdf(ef, x)) - + # Multivariate case - use gradient! and hessian! grad_fn = let ForwardDiff = ForwardDiff (out, x) -> ForwardDiff.gradient!(out, x -> logpdf(ef, x), x) end - hess_fn = let ForwardDiff = ForwardDiff + hess_fn = let ForwardDiff = ForwardDiff (out, x) -> ForwardDiff.hessian!(out, x -> logpdf(ef, x), x) end - + targetfn = InplaceLogpdfGradHess(logpdf_fn, grad_fn, hess_fn) strategy = BonnetStrategy(nsamples = 100) - + state1 = create_state!(strategy, M, parameters, targetfn, ef, supplementary_η) # Test manual state creation and preparation nsamples = get_nsamples(strategy) sample_dim = length(mean(ef)) param_dim = length(getnaturalparameters(ef)) - + # For multivariate: samples is a matrix, grads/hessians are in sample space samples = zeros(paramfloattype(ef), sample_dim, nsamples) # Matrix for multivariate logpdfs = zeros(paramfloattype(ef), nsamples) - logbasemeasures = Fill(ExponentialFamily.logbasemeasure(ef, zeros(sample_dim)), nsamples) + logbasemeasures = + Fill(ExponentialFamily.logbasemeasure(ef, zeros(sample_dim)), nsamples) grads = zeros(paramfloattype(ef), sample_dim, nsamples) # sample_dim x nsamples hessians = zeros(paramfloattype(ef), sample_dim, sample_dim, nsamples) # sample_dim x sample_dim x nsamples current_mean = zeros(paramfloattype(ef), sample_dim) - + state2 = BonnetStrategyState( samples = samples, logpdfs = logpdfs, @@ -1044,17 +1089,10 @@ end hessians = hessians, current_mean = current_mean, ) - + strategy = BonnetStrategy(nsamples = nsamples) - state2_prepared = prepare_state!( - strategy, - state2, - M, - parameters, - targetfn, - ef, - supplementary_η, - ) + state2_prepared = + prepare_state!(strategy, state2, M, parameters, targetfn, ef, supplementary_η) @test get_samples(state1) == get_samples(state2) @test get_logpdfs(state1) == get_logpdfs(state2) @@ -1073,10 +1111,10 @@ end c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) parameters = ProjectionParameters(; rng = rng) - + # Create logpdf function and its derivatives with proper ForwardDiff scoping logpdf_fn = (out, x) -> (out[1] = logpdf(ef, x)) - + # Multivariate case - use gradient! and hessian! grad_fn = let ForwardDiff = ForwardDiff (out, x) -> ForwardDiff.gradient!(out, x -> logpdf(ef, x), x) @@ -1084,12 +1122,12 @@ end hess_fn = let ForwardDiff = ForwardDiff (out, x) -> ForwardDiff.hessian!(out, x -> logpdf(ef, x), x) end - + targetfn = InplaceLogpdfGradHess(logpdf_fn, grad_fn, hess_fn) strategy = BonnetStrategy(nsamples = 10) - + state = create_state!(strategy, M, parameters, targetfn, ef, ()) - + @test size(get_samples(state)) == (2, 10) # Matrix for multivariate @test length(get_logpdfs(state)) == 10 @test size(get_grads(state)) == (2, 10) # sample_dim x nsamples for multivariate @@ -1123,18 +1161,24 @@ end nsamples = 1000 println("\nPerformance comparison: BonnetStrategy vs ControlVariateStrategy") - println("Dimension | BonnetStrategy (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio") - println("----------|---------------------|-----------------------------|---------|--------------") + println( + "Dimension | BonnetStrategy (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio", + ) + println( + "----------|---------------------|-----------------------------|---------|--------------", + ) for dim in dimensions # Create high-dimensional normal distribution μ = randn(StableRNG(42), dim) - Σ = let A = randn(StableRNG(43), dim, dim); A * A' + 0.1 * I end + Σ = let A = randn(StableRNG(43), dim, dim); + A * A' + 0.1 * I + end dist = MvNormalMeanCovariance(μ, Σ) - + # Create target function for the same distribution target_logpdf = (x) -> logpdf(dist, x) - + # Create InplaceLogpdfGradHess for BonnetStrategy logpdf_fn = (out, x) -> (out[1] = logpdf(dist, x)) grad_fn = let ForwardDiff = ForwardDiff @@ -1151,76 +1195,118 @@ end d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + η = getnaturalparameters(ef) - + # Pre-create variables for benchmarking test_parameters = ProjectionParameters(rng = StableRNG(42), seed = 42) p_manifold = ExponentialFamilyManifolds.partition_point(M, η) - + # Benchmark BonnetStrategy using ProjectionCostGradientObjective bonnet_benchmark = @benchmark begin bonnet_strategy = BonnetStrategy(nsamples = $nsamples) - bonnet_state = create_state!(bonnet_strategy, $M, $test_parameters, $bonnet_target, $ef, ()) + bonnet_state = create_state!( + bonnet_strategy, + $M, + $test_parameters, + $bonnet_target, + $ef, + (), + ) bonnet_obj = ProjectionCostGradientObjective( - $test_parameters, $bonnet_target, copy($η), (), bonnet_strategy, bonnet_state + $test_parameters, + $bonnet_target, + copy($η), + (), + bonnet_strategy, + bonnet_state, ) X_bonnet = Manifolds.zero_vector($M, $p_manifold) cost_bonnet, X_bonnet = bonnet_obj($M, X_bonnet, $p_manifold) end - + # Benchmark ControlVariateStrategy using ProjectionCostGradientObjective cv_benchmark = @benchmark begin cv_strategy = ControlVariateStrategy(nsamples = $nsamples, buffer = nothing) - cv_state = create_state!(cv_strategy, $M, $test_parameters, $target_logpdf, $ef, ()) + cv_state = create_state!( + cv_strategy, + $M, + $test_parameters, + $target_logpdf, + $ef, + (), + ) cv_obj = ProjectionCostGradientObjective( - $test_parameters, $target_logpdf, copy($η), (), cv_strategy, cv_state + $test_parameters, + $target_logpdf, + copy($η), + (), + cv_strategy, + cv_state, ) X_cv = Manifolds.zero_vector($M, $p_manifold) cost_cv, X_cv = cv_obj($M, X_cv, $p_manifold) end - + # Extract timing and memory statistics bonnet_time_μs = median(bonnet_benchmark.times) / 1000 # Convert ns to μs cv_time_μs = median(cv_benchmark.times) / 1000 # Convert ns to μs speedup = cv_time_μs / bonnet_time_μs - + bonnet_memory = bonnet_benchmark.memory cv_memory = cv_benchmark.memory memory_ratio = cv_memory / bonnet_memory - + # Print results - println(@sprintf("%9d | %19.1f | %27.1f | %6.2fx | %12.2fx", - dim, bonnet_time_μs, cv_time_μs, speedup, memory_ratio)) - + println( + @sprintf( + "%9d | %19.1f | %27.1f | %6.2fx | %12.2fx", + dim, + bonnet_time_μs, + cv_time_μs, + speedup, + memory_ratio + ) + ) + # Test that both strategies produce similar gradients (functional correctness) parameters = ProjectionParameters(rng = StableRNG(42), seed = 42) p_manifold = ExponentialFamilyManifolds.partition_point(M, η) - + bonnet_strategy = BonnetStrategy(nsamples = nsamples) bonnet_state = create_state!(bonnet_strategy, M, parameters, bonnet_target, ef, ()) bonnet_obj = ProjectionCostGradientObjective( - parameters, bonnet_target, copy(η), (), bonnet_strategy, bonnet_state + parameters, + bonnet_target, + copy(η), + (), + bonnet_strategy, + bonnet_state, ) X_bonnet = Manifolds.zero_vector(M, p_manifold) cost_bonnet, X_bonnet = bonnet_obj(M, X_bonnet, p_manifold) - + cv_strategy = ControlVariateStrategy(nsamples = nsamples, buffer = nothing) cv_state = create_state!(cv_strategy, M, parameters, target_logpdf, ef, ()) cv_obj = ProjectionCostGradientObjective( - parameters, target_logpdf, copy(η), (), cv_strategy, cv_state + parameters, + target_logpdf, + copy(η), + (), + cv_strategy, + cv_state, ) X_cv = Manifolds.zero_vector(M, p_manifold) cost_cv, X_cv = cv_obj(M, X_cv, p_manifold) - + # Verify gradients are approximately equal (using Fisher metric) fisherinformation = ExponentialFamily.fisherinformation(ef) X_diff = X_bonnet - X_cv fisher_norm_diff = sqrt(dot(X_diff, fisherinformation, X_diff)) - + @test_broken fisher_norm_diff < 0.1 # Gradients should be close in Fisher metric @test abs(cost_bonnet - cost_cv) < 0.1 # Costs should be similar - + # Test that BonnetStrategy shows performance advantage for higher dimensions if dim >= 20 @test_broken speedup > 1.0 # BonnetStrategy should be faster for high dimensions @@ -1229,27 +1315,42 @@ end end @testitem "BonnetStrategy vs ControlVariateStrategy performance comparison with analytical target" begin - using ExponentialFamily, Distributions, BayesBase, LinearAlgebra, Random, StableRNGs, ExponentialFamilyManifolds, Printf, BenchmarkTools, Manifolds - import ExponentialFamilyProjection: BonnetStrategy, ControlVariateStrategy, InplaceLogpdfGradHess, create_state!, ProjectionParameters, ProjectionCostGradientObjective + using ExponentialFamily, + Distributions, + BayesBase, + LinearAlgebra, + Random, + StableRNGs, + ExponentialFamilyManifolds, + Printf, + BenchmarkTools, + Manifolds + import ExponentialFamilyProjection: + BonnetStrategy, + ControlVariateStrategy, + InplaceLogpdfGradHess, + create_state!, + ProjectionParameters, + ProjectionCostGradientObjective # Simple analytical target: -||x - 1||^2 # logpdf(x) = -||x - 1||^2 # grad(x) = -2(x - 1) # hess(x) = -2I - + function analytical_logpdf!(out, x) - out[1] = -sum((x .- 1).^2) + out[1] = -sum((x .- 1) .^ 2) return out end - + function analytical_grad!(out, x) out .= -2 .* (x .- 1) return out end - + function analytical_hess!(out, x) fill!(out, 0.0) - for i in 1:size(out, 1) + for i = 1:size(out, 1) out[i, i] = -2.0 end return out @@ -1258,18 +1359,24 @@ end dimensions = [10, 20, 50] nsamples = 1000 - println("\nPerformance comparison with analytical target: BonnetStrategy vs ControlVariateStrategy") - println("Dimension | BonnetStrategy (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio") - println("----------|---------------------|-----------------------------|---------|--------------") + println( + "\nPerformance comparison with analytical target: BonnetStrategy vs ControlVariateStrategy", + ) + println( + "Dimension | BonnetStrategy (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio", + ) + println( + "----------|---------------------|-----------------------------|---------|--------------", + ) for dim in dimensions rng = StableRNG(42) - + # Create target distribution (we'll project to Normal) target_mean = ones(dim) target_cov = Matrix(I, dim, dim) dist = MvNormal(target_mean, target_cov) - + # Create exponential family distribution and manifold ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) @@ -1277,53 +1384,85 @@ end c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) η = getnaturalparameters(ef) - + # Create analytical target for BonnetStrategy - analytical_target = InplaceLogpdfGradHess(analytical_logpdf!, analytical_grad!, analytical_hess!) - + analytical_target = + InplaceLogpdfGradHess(analytical_logpdf!, analytical_grad!, analytical_hess!) + # Create simple logpdf function for ControlVariateStrategy - target_logpdf(x) = -sum((x .- 1).^2) - + target_logpdf(x) = -sum((x .- 1) .^ 2) + # Pre-create variables for benchmarking test_parameters = ProjectionParameters(rng = StableRNG(42), seed = 42) p_manifold = ExponentialFamilyManifolds.partition_point(M, η) - + # Benchmark BonnetStrategy using ProjectionCostGradientObjective bonnet_benchmark = @benchmark begin bonnet_strategy = BonnetStrategy(nsamples = $nsamples) - bonnet_state = create_state!(bonnet_strategy, $M, $test_parameters, $analytical_target, $ef, ()) + bonnet_state = create_state!( + bonnet_strategy, + $M, + $test_parameters, + $analytical_target, + $ef, + (), + ) bonnet_obj = ProjectionCostGradientObjective( - $test_parameters, $analytical_target, copy($η), (), bonnet_strategy, bonnet_state + $test_parameters, + $analytical_target, + copy($η), + (), + bonnet_strategy, + bonnet_state, ) X_bonnet = Manifolds.zero_vector($M, $p_manifold) cost_bonnet, X_bonnet = bonnet_obj($M, X_bonnet, $p_manifold) end - + # Benchmark ControlVariateStrategy using ProjectionCostGradientObjective cv_benchmark = @benchmark begin cv_strategy = ControlVariateStrategy(nsamples = $nsamples, buffer = nothing) - cv_state = create_state!(cv_strategy, $M, $test_parameters, $target_logpdf, $ef, ()) + cv_state = create_state!( + cv_strategy, + $M, + $test_parameters, + $target_logpdf, + $ef, + (), + ) cv_obj = ProjectionCostGradientObjective( - $test_parameters, $target_logpdf, copy($η), (), cv_strategy, cv_state + $test_parameters, + $target_logpdf, + copy($η), + (), + cv_strategy, + cv_state, ) X_cv = Manifolds.zero_vector($M, $p_manifold) cost_cv, X_cv = cv_obj($M, X_cv, $p_manifold) end - + # Extract timing and memory statistics bonnet_time_μs = median(bonnet_benchmark.times) / 1000 cv_time_μs = median(cv_benchmark.times) / 1000 speedup = cv_time_μs / bonnet_time_μs - + bonnet_memory = bonnet_benchmark.memory cv_memory = cv_benchmark.memory memory_ratio = cv_memory / bonnet_memory - - println(@sprintf("%9d | %19.1f | %27.1f | %6.2fx | %12.2fx |", - dim, bonnet_time_μs, cv_time_μs, speedup, memory_ratio)) + + println( + @sprintf( + "%9d | %19.1f | %27.1f | %6.2fx | %12.2fx |", + dim, + bonnet_time_μs, + cv_time_μs, + speedup, + memory_ratio + ) + ) @test speedup > 1.0 @test memory_ratio > 1.0 end end - diff --git a/test/strategies/control_variate_tests.jl b/test/strategies/control_variate_tests.jl index 7b56a52..b818177 100644 --- a/test/strategies/control_variate_tests.jl +++ b/test/strategies/control_variate_tests.jl @@ -131,7 +131,8 @@ end parameters = ProjectionParameters(rng = rng) strategy = ControlVariateStrategy(nsamples = nsamples) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) create_state!( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) create_state!( strategy, M, parameters, @@ -140,31 +141,36 @@ end supplementary_η, ) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_samples_container( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_samples_container( rng, ef, nsamples, supplementary_η, ) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_logpdfs_container( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_logpdfs_container( rng, ef, nsamples, supplementary_η, ) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_logbasemeasures_container( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_logbasemeasures_container( rng, ef, nsamples, supplementary_η, ) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_sufficientstatistics_container( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_sufficientstatistics_container( rng, ef, nsamples, supplementary_η, ) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_gradsamples_container( + @test_opt ignored_modules = + (Base, LinearAlgebra, Distributions, ForwardDiff) ExponentialFamilyProjection.prepare_gradsamples_container( rng, ef, nsamples, @@ -459,4 +465,4 @@ end # Small differences are allowed due to different LinearAlgebra routines @test result_with_buffer ≈ result_without_buffer end -end \ No newline at end of file +end diff --git a/test/strategies/default_tests.jl b/test/strategies/default_tests.jl index 3c5ce49..33c58c6 100644 --- a/test/strategies/default_tests.jl +++ b/test/strategies/default_tests.jl @@ -30,4 +30,4 @@ @test result_1 ≈ distribution atol = 1e-1 @test result_2 ≈ distribution atol = 1e-1 end -end \ No newline at end of file +end diff --git a/test/strategies/gauss_newton_tests.jl b/test/strategies/gauss_newton_tests.jl index d555734..9a8cbe6 100644 --- a/test/strategies/gauss_newton_tests.jl +++ b/test/strategies/gauss_newton_tests.jl @@ -183,13 +183,19 @@ end nsamples = 1000 println("\nPerformance comparison: GaussNewton vs ControlVariateStrategy") - println("Dimension | GaussNewton (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio") - println("----------|-------------------|-----------------------------|---------|--------------") + println( + "Dimension | GaussNewton (μs) | ControlVariateStrategy (μs) | Speedup | Memory Ratio", + ) + println( + "----------|-------------------|-----------------------------|---------|--------------", + ) for dim in dimensions # Create high-dimensional normal distribution μ = randn(StableRNG(42), dim) - Σ = let A = randn(StableRNG(43), dim, dim); A * A' + 0.1 * I end + Σ = let A = randn(StableRNG(43), dim, dim); + A * A' + 0.1 * I + end dist = MvNormalMeanCovariance(μ, Σ) # Create target function for the same distribution @@ -221,9 +227,15 @@ end # Benchmark GaussNewton using ProjectionCostGradientObjective gn_benchmark = @benchmark begin gn_strategy = GaussNewton(nsamples = $nsamples) - gn_state = create_state!(gn_strategy, $M, $test_parameters, $gn_target, $ef, ()) + gn_state = + create_state!(gn_strategy, $M, $test_parameters, $gn_target, $ef, ()) gn_obj = ProjectionCostGradientObjective( - $test_parameters, $gn_target, copy($η), (), gn_strategy, gn_state + $test_parameters, + $gn_target, + copy($η), + (), + gn_strategy, + gn_state, ) X_gn = Manifolds.zero_vector($M, $p_manifold) cost_gn, X_gn = gn_obj($M, X_gn, $p_manifold) @@ -232,9 +244,21 @@ end # Benchmark ControlVariateStrategy using ProjectionCostGradientObjective cv_benchmark = @benchmark begin cv_strategy = ControlVariateStrategy(nsamples = $nsamples, buffer = nothing) - cv_state = create_state!(cv_strategy, $M, $test_parameters, $target_logpdf, $ef, ()) + cv_state = create_state!( + cv_strategy, + $M, + $test_parameters, + $target_logpdf, + $ef, + (), + ) cv_obj = ProjectionCostGradientObjective( - $test_parameters, $target_logpdf, copy($η), (), cv_strategy, cv_state + $test_parameters, + $target_logpdf, + copy($η), + (), + cv_strategy, + cv_state, ) X_cv = Manifolds.zero_vector($M, $p_manifold) cost_cv, X_cv = cv_obj($M, X_cv, $p_manifold) @@ -250,8 +274,16 @@ end memory_ratio = cv_memory / gn_memory # Print results - println(@sprintf("%9d | %17.1f | %27.1f | %6.2fx | %12.2fx", - dim, gn_time_μs, cv_time_μs, speedup, memory_ratio)) + println( + @sprintf( + "%9d | %17.1f | %27.1f | %6.2fx | %12.2fx", + dim, + gn_time_μs, + cv_time_μs, + speedup, + memory_ratio + ) + ) # Test that both strategies produce similar outputs (functional correctness) parameters = ProjectionParameters(rng = StableRNG(42), seed = 42) @@ -260,7 +292,12 @@ end gn_strategy = GaussNewton(nsamples = nsamples) gn_state = create_state!(gn_strategy, M, parameters, gn_target, ef, ()) gn_obj = ProjectionCostGradientObjective( - parameters, gn_target, copy(η), (), gn_strategy, gn_state + parameters, + gn_target, + copy(η), + (), + gn_strategy, + gn_state, ) X_gn = Manifolds.zero_vector(M, p_manifold) cost_gn, X_gn = gn_obj(M, X_gn, p_manifold) @@ -268,7 +305,12 @@ end cv_strategy = ControlVariateStrategy(nsamples = nsamples, buffer = nothing) cv_state = create_state!(cv_strategy, M, parameters, target_logpdf, ef, ()) cv_obj = ProjectionCostGradientObjective( - parameters, target_logpdf, copy(η), (), cv_strategy, cv_state + parameters, + target_logpdf, + copy(η), + (), + cv_strategy, + cv_state, ) X_cv = Manifolds.zero_vector(M, p_manifold) cost_cv, X_cv = cv_obj(M, X_cv, p_manifold) @@ -279,5 +321,3 @@ end end end end - - diff --git a/test/strategies/mle_tests.jl b/test/strategies/mle_tests.jl index 61e112f..86f85c1 100644 --- a/test/strategies/mle_tests.jl +++ b/test/strategies/mle_tests.jl @@ -37,7 +37,9 @@ end Poisson(0.5), Chisq(30.0), Gamma(1, 1), - ],nsamples in (100, 500) + ], + nsamples in (100, 500) + ef = convert(ExponentialFamilyDistribution, distribution) samples = rand(rng, ef, nsamples) @@ -49,7 +51,14 @@ end η = getnaturalparameters(ef) strategy = ExponentialFamilyProjection.MLEStrategy() - state = ExponentialFamilyProjection.create_state!(strategy, M, proj_params, samples, ef, ()) + state = ExponentialFamilyProjection.create_state!( + strategy, + M, + proj_params, + samples, + ef, + (), + ) obj = ExponentialFamilyProjection.ProjectionCostGradientObjective( proj_params, samples, @@ -113,4 +122,4 @@ end @test X_p ≈ _inv_fisher * expected_gradient_p end end -end \ No newline at end of file +end