Skip to content

Commit 7fc526d

Browse files
authored
Merge pull request #469 from ReactiveBayes/fix-tests
Fix tests for new Fast Cholesky
2 parents d5de6df + 6b14e74 commit 7fc526d

59 files changed

Lines changed: 452 additions & 221 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/FormatCheck.yml

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
name: Format Check
2+
on:
3+
push:
4+
branches:
5+
- main
6+
tags: ['*']
7+
pull_request:
8+
workflow_dispatch:
9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
14+
jobs:
15+
format-check:
16+
name: Code Format Check
17+
runs-on: ubuntu-latest
18+
# Don't run on PRs that come from forks as they won't have permission to create PRs
19+
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
20+
permissions:
21+
contents: write # Needed to push commits
22+
issues: write # Needed to create PRs and write comments
23+
pull-requests: write # Needed to create PRs and write comments
24+
steps:
25+
- uses: actions/checkout@v4
26+
with:
27+
ref: ${{ github.head_ref }}
28+
token: ${{ secrets.GITHUB_TOKEN }}
29+
- uses: julia-actions/setup-julia@v2
30+
- uses: julia-actions/cache@v2
31+
32+
# Find existing format PR if any
33+
- name: Find existing format PR
34+
id: find_pr
35+
uses: actions/github-script@v7
36+
with:
37+
github-token: ${{ secrets.GITHUB_TOKEN }}
38+
script: |
39+
const prNumber = ${{ github.event.pull_request.number }};
40+
const owner = context.repo.owner;
41+
const repo = context.repo.repo;
42+
43+
// Look for open PRs with our auto-format branch pattern that targets this PR's branch
44+
const prs = await github.rest.pulls.list({
45+
owner,
46+
repo,
47+
state: 'open',
48+
base: '${{ github.head_ref }}'
49+
});
50+
51+
const formatPr = prs.data.find(pr => pr.head.ref.startsWith('auto-format-') &&
52+
pr.title === "🤖 Auto-format Julia code");
53+
54+
if (formatPr) {
55+
console.log(`Found existing format PR: #${formatPr.number}`);
56+
return formatPr.number;
57+
}
58+
59+
return '';
60+
61+
- name: Run formatter check
62+
id: format_check
63+
run: |
64+
if ! make check-format; then
65+
echo "format_needs_fix=true" >> $GITHUB_OUTPUT
66+
else
67+
echo "format_needs_fix=false" >> $GITHUB_OUTPUT
68+
fi
69+
70+
# Close any existing formatting PR if the check now passes
71+
- name: Close existing format PR if check passes
72+
if: steps.format_check.outputs.format_needs_fix == 'false' && steps.find_pr.outputs.result != ''
73+
uses: actions/github-script@v7
74+
with:
75+
github-token: ${{ secrets.GITHUB_TOKEN }}
76+
script: |
77+
const formatPrNumber = Number(${{ steps.find_pr.outputs.result }});
78+
79+
if (formatPrNumber === 0) {
80+
return;
81+
}
82+
83+
const owner = context.repo.owner;
84+
const repo = context.repo.repo;
85+
86+
// Close the PR with a comment
87+
await github.rest.issues.createComment({
88+
owner,
89+
repo,
90+
issue_number: formatPrNumber,
91+
body: `Closing this PR as the code formatting issues in the original PR have been resolved.`
92+
});
93+
94+
await github.rest.pulls.update({
95+
owner,
96+
repo,
97+
pull_number: formatPrNumber,
98+
state: 'closed'
99+
});
100+
101+
console.log(`Closed format PR #${formatPrNumber} as the original PR now passes formatting checks.`);
102+
103+
- name: Apply formatter if needed
104+
if: steps.format_check.outputs.format_needs_fix == 'true'
105+
run: |
106+
make format
107+
108+
- name: Commit changes and create/update PR
109+
if: steps.format_check.outputs.format_needs_fix == 'true'
110+
uses: peter-evans/create-pull-request@v7
111+
with:
112+
token: ${{ secrets.GITHUB_TOKEN }}
113+
commit-message: "🤖 Auto-format Julia code"
114+
title: "🤖 Auto-format Julia code"
115+
body: |
116+
This PR was automatically created to fix Julia code formatting issues.
117+
118+
The formatting was applied using JuliaFormatter according to the project's style guidelines.
119+
120+
Please review the changes and merge if appropriate.
121+
branch: auto-format-${{ github.event.pull_request.number }}
122+
base: ${{ github.head_ref }}
123+
delete-branch: true
124+
labels: |
125+
automated pr
126+
code style
127+
id: create-pr
128+
129+
- name: Comment on original PR
130+
if: steps.format_check.outputs.format_needs_fix == 'true' && steps.create-pr.outputs.pull-request-number && steps.find_pr.outputs.result == ''
131+
uses: actions/github-script@v7
132+
with:
133+
github-token: ${{ secrets.GITHUB_TOKEN }}
134+
script: |
135+
const prNumber = ${{ github.event.pull_request.number }};
136+
const formatPrNumber = ${{ steps.create-pr.outputs.pull-request-number }};
137+
const formatPrUrl = `https://github.com/${{ github.repository }}/pull/${formatPrNumber}`;
138+
139+
await github.rest.issues.createComment({
140+
owner: context.repo.owner,
141+
repo: context.repo.repo,
142+
issue_number: prNumber,
143+
body: `## 🤖 Code Formatting
144+
145+
This PR has some code formatting issues. I've created [PR #${formatPrNumber}](${formatPrUrl}) with the necessary formatting changes.
146+
147+
You can merge that PR into this branch to fix the code style check.
148+
149+
Alternatively, you can run \`make format\` locally and push the changes yourself.`
150+
});
151+
152+
- name: Comment on original PR for updated formatting PR
153+
if: steps.format_check.outputs.format_needs_fix == 'true' && steps.create-pr.outputs.pull-request-number && steps.find_pr.outputs.result != ''
154+
uses: actions/github-script@v7
155+
with:
156+
github-token: ${{ secrets.GITHUB_TOKEN }}
157+
script: |
158+
const prNumber = ${{ github.event.pull_request.number }};
159+
const formatPrNumber = ${{ steps.create-pr.outputs.pull-request-number }};
160+
const formatPrUrl = `https://github.com/${{ github.repository }}/pull/${formatPrNumber}`;
161+
162+
await github.rest.issues.createComment({
163+
owner: context.repo.owner,
164+
repo: context.repo.repo,
165+
issue_number: prNumber,
166+
body: `## 🤖 Code Formatting
167+
168+
Your PR still has some code formatting issues. I've updated [PR #${formatPrNumber}](${formatPrUrl}) with the necessary formatting changes.
169+
170+
You can merge that PR into this branch to fix the code style check.
171+
172+
Alternatively, you can run \`make format\` locally and push the changes yourself.`
173+
});
174+
175+
# Fail the job if formatting was needed and applied
176+
- name: Fail if formatting was needed
177+
if: steps.format_check.outputs.format_needs_fix == 'true'
178+
run: |
179+
echo "::error::Code formatting issues detected. A PR with fixes has been created, but this check is failing to indicate that formatting needs to be fixed."
180+
exit 1

.github/workflows/ci.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77
push:
88
branches:
99
- 'main'
10-
tags: '*'
10+
tags: ['*']
1111
check_run:
1212
types: [rerequested]
1313
schedule:
@@ -18,11 +18,6 @@ permissions:
1818
contents: read
1919

2020
jobs:
21-
code-style:
22-
name: Code Style Suggestions
23-
runs-on: ubuntu-latest
24-
steps:
25-
- uses: julia-actions/julia-format@v3
2621
test:
2722
name: Tests ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
2823
runs-on: ${{ matrix.os }}

Makefile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ SHELL = /bin/bash
66
scripts_init:
77
julia --project=scripts/ -e 'using Pkg; Pkg.instantiate(); Pkg.update(); Pkg.precompile();'
88

9-
lint: scripts_init ## Code formating check
10-
julia --project=scripts/ scripts/format.jl
9+
format: scripts_init ## Format Julia code
10+
julia --project=scripts/ scripts/formatter.jl --overwrite
1111

12-
format: scripts_init ## Code formating run
13-
julia --project=scripts/ scripts/format.jl --overwrite
12+
check-format: scripts_init ## Check Julia code formatting (does not modify files)
13+
julia --project=scripts/ scripts/formatter.jl
1414

1515
.PHONY: benchmark
1616

scripts/format.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

scripts/formatter.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using JuliaFormatter
2+
using ArgParse
3+
4+
s = ArgParseSettings()
5+
6+
@add_arg_table s begin
7+
"--overwrite"
8+
help = "Overwrite the files with the formatted code"
9+
action = :store_true
10+
default = false
11+
end
12+
13+
args = parse_args(ARGS, s)
14+
overwrite = args["overwrite"]
15+
projectroot = joinpath(@__DIR__, "..")
16+
17+
passed = format(projectroot, verbose = true, overwrite = overwrite)
18+
19+
if !passed && !overwrite
20+
@error "JuliaFormatter check has failed. Run `make format` from the main directory and commit your changes to fix code style."
21+
exit(1)
22+
elseif !passed && overwrite
23+
@info "JuliaFormatter has overwritten files according to style guidelines"
24+
elseif passed
25+
@info "Codestyle from JuliaFormatted checks have passed"
26+
end

src/approximations/cvi.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,11 @@ function prod(approximation::CVI, outbound, inbound)
202202

203203
# Initial parameters of projected distribution
204204
current_ef = convert(ExponentialFamilyDistribution, inbound) # current EF distribution
205-
current_λ = getnaturalparameters(current_ef) # current natural parameters
205+
current_λ = getnaturalparameters(current_ef) # current natural parameters
206206
scontainer = rand(rng, sampling_optimized(inbound), approximation.n_gradpoints) # sampling container
207-
current_∇ = similar(current_λ) # current gradient
208-
new_λ = similar(current_λ) # new natural parameters
209-
cache = similar(current_λ) # just intermediate buffer
207+
current_∇ = similar(current_λ) # current gradient
208+
new_λ = similar(current_λ) # new natural parameters
209+
cache = similar(current_λ) # just intermediate buffer
210210

211211
# We avoid use of lambda functions, because they cannot capture `T`
212212
# which leads to performance issues

src/approximations/optimizers/adam.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ mutable struct Adam{T} <: Optimizer
66
r :: T
77
shat :: T
88
rhat :: T
9-
ρ1 :: Float64
10-
ρ2 :: Float64
11-
λ :: Float64
9+
ρ1 :: Float64
10+
ρ2 :: Float64
11+
λ :: Float64
1212
it :: Int64
1313
tmp :: T
1414
end
@@ -26,9 +26,9 @@ gets(optimizer::Adam) = return optimizer.s
2626
getr(optimizer::Adam) = return optimizer.r
2727
getshat(optimizer::Adam) = return optimizer.shat
2828
getrhat(optimizer::Adam) = return optimizer.rhat
29-
getρ1(optimizer::Adam) = return optimizer.ρ1
30-
getρ2(optimizer::Adam) = return optimizer.ρ2
31-
getλ(optimizer::Adam) = return optimizer.λ
29+
getρ1(optimizer::Adam) = return optimizer.ρ1
30+
getρ2(optimizer::Adam) = return optimizer.ρ2
31+
getλ(optimizer::Adam) = return optimizer.λ
3232
getit(optimizer::Adam) = return optimizer.it
3333
getall(optimizer::Adam) = return optimizer.x, optimizer.s, optimizer.r, optimizer.shat, optimizer.rhat, optimizer.ρ1, optimizer.ρ2, optimizer.λ, optimizer.it, optimizer.tmp
3434

src/approximations/unscented.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ const default_beta = 2.0
55
const default_kappa = 0.0
66

77
struct UnscentedExtra{T, R, M, C}
8-
L :: T
9-
λ :: R
10-
Wm :: M
11-
Wc :: C
8+
L::T
9+
λ::R
10+
Wm::M
11+
Wc::C
1212
end
1313

1414
"""
@@ -72,13 +72,13 @@ getκ(approximation::Unscented) = approximation.κ
7272

7373
getextra(approximation::Unscented) = approximation.e
7474

75-
getL(approximation::Unscented) = getL(getextra(approximation))
76-
getλ(approximation::Unscented) = getλ(getextra(approximation))
75+
getL(approximation::Unscented) = getL(getextra(approximation))
76+
getλ(approximation::Unscented) = getλ(getextra(approximation))
7777
getWm(approximation::Unscented) = getWm(getextra(approximation))
7878
getWc(approximation::Unscented) = getWc(getextra(approximation))
7979

80-
getL(extra::UnscentedExtra) = extra.L
81-
getλ(extra::UnscentedExtra) = extra.λ
80+
getL(extra::UnscentedExtra) = extra.L
81+
getλ(extra::UnscentedExtra) = extra.λ
8282
getWm(extra::UnscentedExtra) = extra.Wm
8383
getWc(extra::UnscentedExtra) = extra.Wc
8484

src/helpers/macrohelpers.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,16 @@ __test_inferred_typeof(x) = typeof(x)
105105
__test_inferred_typeof(::Type{T}) where {T} = Type{T}
106106

107107
macro test_inferred(T, expression)
108-
return esc(
109-
quote
110-
let
111-
local result = Test.@inferred($expression)
112-
if !(ReactiveMP.MacroHelpers.__test_inferred_typeof(result) <: $T)
113-
error("Result type $(ReactiveMP.MacroHelpers.__test_inferred_typeof(result)) does not match allowed type $T")
114-
end
115-
@test ReactiveMP.MacroHelpers.__test_inferred_typeof(result) <: $T
116-
result
108+
return esc(quote
109+
let
110+
local result = Test.@inferred($expression)
111+
if !(ReactiveMP.MacroHelpers.__test_inferred_typeof(result) <: $T)
112+
error("Result type $(ReactiveMP.MacroHelpers.__test_inferred_typeof(result)) does not match allowed type $T")
117113
end
114+
@test ReactiveMP.MacroHelpers.__test_inferred_typeof(result) <: $T
115+
result
118116
end
119-
)
117+
end)
120118
end
121119

122120
end

src/message.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,9 @@ function multiply_messages(prod_strategy, left::Message, right::Message)
136136
return Message(new_dist, is_prod_clamped, is_prod_initial, new_addons)
137137
end
138138

139-
constrain_form_as_message(message::Message, form_constraint) =
140-
Message(constrain_form(form_constraint, getdata(message)), is_clamped(message), is_initial(message), getaddons(message))
139+
constrain_form_as_message(message::Message, form_constraint) = Message(
140+
constrain_form(form_constraint, getdata(message)), is_clamped(message), is_initial(message), getaddons(message)
141+
)
141142

142143
# Note: we need extra Base.Generator(as_message, messages) step here, because some of the messages might be VMP messages
143144
# We want to cast it explicitly to a Message structure (which as_message does in case of DeferredMessage)

0 commit comments

Comments
 (0)