Skip to content

Commit 6d9e3b2

Browse files
committed
add type hint for julia 1.10
to make loss_gf inferred
1 parent fbb0d3d commit 6d9e3b2

3 files changed

Lines changed: 10 additions & 6 deletions

File tree

src/elbo.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,13 @@ function compute_priors_logdensity(priorsP, priorsM, θP, θMs, zero_prior_logde
246246
#TRET = Base.return_types(logpdf_tv_sum, Tuple{typeof(priorsM[i]), typeof(θMs[:,i])})
247247
end
248248
end
249-
nlMs_sum = sum(f_col, 1:length(priorsM))
249+
nlMs_sum = sum(f_col, 1:length(priorsM))::typeof(nlP0) # not type inferred in julia 1.10
250250
neg_log_prior_i = nlP0 - nlMs_sum
251251
if !isfinite(neg_log_prior_i)
252252
@show neg_log_prior_i, nlP0
253253
@show θMs
254254
@show priorsM
255255
error("inspect non-finite priors")
256-
i_par = 2
257-
priorMi, θMi = priorsM[i_par], eachcol(θMs)[2]
258256
end
259257
neg_log_prior_i
260258
end

src/gf.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,10 @@ function get_loss_gf(g, transM, transP, f, py,
277277
logpdf_tv = (prior, θ::AbstractVector) -> begin
278278
map(Base.Fix1(logpdf, prior), θ)::Vector{eltype(θP_pred)}
279279
end
280-
neg_log_prior = compute_priors_logdensity(priorsP, priorsM, θP_pred, θMs_pred,
281-
is_omit_priors, zero_prior_logdensity)
280+
neg_log_prior =
281+
# @descend_code_warntype (
282+
compute_priors_logdensity(priorsP, priorsM, θP_pred, θMs_pred,
283+
is_omit_priors, zero_prior_logdensity)
282284
if !isfinite(neg_log_prior)
283285
@info "loss_gf: encountered non-finite prior density"
284286
@show θP_pred, θMs_pred, ϕc.ϕP

test/test_HybridProblem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,12 @@ test_without_flux = (scenario) -> begin
171171
(_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader)
172172
#l1 = loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites; is_testmode = false)
173173

174+
#using ShareAdd
175+
#@usingany Cthulhu
174176
l1 = @inferred (
175177
# @descend_code_warntype (
176-
loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites; is_testmode = true))
178+
loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites; is_testmode = true)
179+
)
177180
tld = first(train_loader)
178181
gr = Zygote.gradient(p -> loss_gf(p, tld...; is_testmode = false)[1], CA.getdata(p0))
179182
@test gr[1] isa Vector
@@ -198,6 +201,7 @@ test_without_flux = (scenario) -> begin
198201
end
199202

200203
#test_without_flux(Val((:MeanHVIApproximation,))) # not used in loss_gf
204+
#scenario=Val((:default,))
201205
test_without_flux(Val((:default,)))
202206
test_without_flux(Val((:covarK2,)))
203207

0 commit comments

Comments
 (0)