Skip to content

Commit 7396423

Browse files
authored
Merge pull request #80 from CosmologicalEmulators/develop
Update ci.yml
2 parents 6583408 + fa7a474 commit 7396423

5 files changed

Lines changed: 139 additions & 95 deletions

File tree

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212
tags: "*"
1313
jobs:
1414
test:
15-
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
15+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
1616
runs-on: ${{ matrix.os }}
1717
strategy:
1818
fail-fast: false
@@ -42,11 +42,15 @@ jobs:
4242
${{ runner.os }}-
4343
- uses: julia-actions/julia-buildpkg@v1
4444
- uses: julia-actions/julia-runtest@v1
45+
# Only process and upload coverage once, on a single representative job
4546
- uses: julia-actions/julia-processcoverage@v1
47+
if: matrix.version == '1.10' && matrix.os == 'ubuntu-latest'
4648
with:
4749
directories: src,ext
4850
- uses: codecov/codecov-action@v5
51+
if: matrix.version == '1.10' && matrix.os == 'ubuntu-latest'
4952
with:
5053
files: lcov.info
5154
token: ${{ secrets.CODECOV_TOKEN }}
5255
fail_ci_if_error: false
56+
verbose: true

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,42 @@
11
name = "AbstractCosmologicalEmulators"
22
uuid = "c83c1981-e5c4-4837-9eb8-c9b1572acfc6"
3-
version = "0.9.3"
3+
version = "0.9.4"
44
authors = ["Marco Bonici <bonici.marco@gmail.com>"]
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
810
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
912
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1013
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1114
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1215
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1316
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
17+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1418
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
1519
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
20+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1621

1722
[weakdeps]
1823
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
1924
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
2025
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
21-
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2226
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
2327
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2428

2529
[extensions]
2630
BackgroundCosmologyExt = ["DataInterpolations", "FastGaussQuadrature", "Integrals", "LinearAlgebra", "OrdinaryDiffEqTsit5", "SciMLSensitivity"]
27-
MooncakeExt = ["Mooncake"]
31+
MooncakeExt = ["ForwardDiff", "Mooncake"]
2832

2933
[compat]
34+
ADTypes = "1.21.0"
3035
Adapt = "3, 4"
36+
BenchmarkTools = "1.6.3"
3137
ChainRulesCore = "1.26"
3238
DataInterpolations = "6, 8"
39+
DifferentiationInterface = "0.7.16"
3340
FFTW = "1"
3441
FastGaussQuadrature = "1"
3542
ForwardDiff = "0.10, 1"
@@ -42,6 +49,7 @@ NPZ = "0.4"
4249
OrdinaryDiffEqTsit5 = "1"
4350
SciMLSensitivity = "7.90"
4451
SimpleChains = "0.4"
52+
Zygote = "0.7.10"
4553
julia = "1.10"
4654

4755
[extras]

ext/MooncakeExt/MooncakeExt.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,39 @@ Mooncake.increment_rdata!!(x::FFTW.FFTWPlan, ::NoRData) = x
4848

4949
@from_chainrules MinimalCtx Tuple{typeof(AbstractCosmologicalEmulators.chebyshev_decomposition), Any, Any}
5050

51+
52+
using ForwardDiff
53+
using Lux
54+
using ChainRulesCore
55+
56+
# Define ChainRulesCore.rrule for run_emulator (specifically for LuxEmulator)
57+
function ChainRulesCore.rrule(::typeof(AbstractCosmologicalEmulators.run_emulator), input, emulator::AbstractCosmologicalEmulators.LuxEmulator)
58+
y = AbstractCosmologicalEmulators.run_emulator(input, emulator)
59+
60+
function run_emulator_pullback(Δy)
61+
if Δy isa ChainRulesCore.AbstractZero
62+
return NoTangent(), ZeroTangent(), NoTangent()
63+
end
64+
65+
# Ensure Δy is a dense vector
66+
Δy_vec = collect(vec(ChainRulesCore.unthunk(Δy)))
67+
68+
# ForwardDiff VJP
69+
vjp_input = convert(typeof(input), ForwardDiff.gradient(
70+
x -> begin
71+
y_dual, _ = Lux.apply(emulator.Model, x, emulator.Parameters, emulator.States)
72+
sum(y_dual .* Δy_vec)
73+
end,
74+
input
75+
))
76+
77+
return NoTangent(), vjp_input, NoTangent()
78+
end
79+
80+
return y, run_emulator_pullback
81+
end
82+
83+
# Register it for Mooncake
84+
Mooncake.@from_chainrules Mooncake.MinimalCtx Tuple{typeof(AbstractCosmologicalEmulators.run_emulator), Any, AbstractCosmologicalEmulators.LuxEmulator}
85+
5186
end # module MooncakeExt

0 commit comments

Comments
 (0)