Skip to content

Commit d7320fc

Browse files
SebastianM-Cclaude
andcommitted
Clean up the radiation module API
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e8e7fe0 commit d7320fc

2 files changed

Lines changed: 102 additions & 45 deletions

File tree

scripts/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
23
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
4+
DiffEqGPU = "071ae1c0-96b5-11e9-1965-c90190d839ea"
35
ElectronDynamicsModels = "acecdaf2-97b2-47e1-90eb-3efa7bb274e5"
46
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
57
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"

src/radiation.jl

Lines changed: 100 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
# Trajectory access (thread-safe interpolation wrapper)
2-
struct TrajectoryInterpolant{I, R, U}
2+
struct TrajectoryInterpolant{I, R, U, T}
33
itp::I # DataInterpolations interpolant (SVector{8} → SVector{8})
44
r_idxs::R # SVector{4, Int} for x⁰, x¹, x², x³
55
u_idxs::U # SVector{4, Int} for u⁰, u¹, u², u³
6+
K::T # q_e / (4π ε₀ c) — Liénard-Wiechert prefactor
67
end
78

89
function TrajectoryInterpolant(sol::SciMLBase.AbstractODESolution, x_syms, u_syms)
910
r_idxs = SVector{4, Int}(variable_index.((sol,), collect(x_syms)))
1011
u_idxs = SVector{4, Int}(variable_index.((sol,), collect(u_syms)))
1112
itp = CubicSpline(sol.u, sol.t; extrapolation = ExtrapolationType.Extension)
12-
return TrajectoryInterpolant(itp, r_idxs, u_idxs)
13+
sys = sol.prob.f.sys
14+
_ref_frame = _find_ref_frame(sys)
15+
K = sol.ps[_ref_frame.q_e / (4π * _ref_frame.ε₀ * _ref_frame.c)]
16+
return TrajectoryInterpolant(itp, r_idxs, u_idxs, K)
1317
end
1418

1519
function (t::TrajectoryInterpolant)(τ)
@@ -38,6 +42,27 @@ struct ObserverScreen{G, T, R}
3842
x⁰_samples::R # uniform observer-time sampling grid
3943
end
4044

45+
function Base.show(io::IO, s::ObserverScreen)
46+
Nx, Ny = length(s.x_grid), length(s.y_grid)
47+
N = length(s.x⁰_samples)
48+
δx⁰ = N > 1 ? step(s.x⁰_samples) : 0.0
49+
return print(io, "ObserverScreen($(Nx)×$(Ny) pixels, z=$(s.z), $(N) time samples, Δx⁰=$(δx⁰))")
50+
end
51+
52+
function Base.show(io::IO, ::MIME"text/plain", s::ObserverScreen)
53+
Nx, Ny = length(s.x_grid), length(s.y_grid)
54+
N = length(s.x⁰_samples)
55+
δx⁰ = N > 1 ? step(s.x⁰_samples) : 0.0
56+
println(io, "ObserverScreen")
57+
println(io, " pixels: $(Nx) × $(Ny)")
58+
println(io, " x range: [$(first(s.x_grid)), $(last(s.x_grid))]")
59+
println(io, " y range: [$(first(s.y_grid)), $(last(s.y_grid))]")
60+
println(io, " z: $(s.z)")
61+
println(io, " time samples: $(N)")
62+
println(io, " x⁰ range: [$(first(s.x⁰_samples)), $(last(s.x⁰_samples))]")
63+
return print(io, " Δx⁰: $(δx⁰)")
64+
end
65+
4166
# dτᵣ/dt = 1/(u⁰(τᵣ) - u⃗(τᵣ)·n̂(τᵣ, r_obs))
4267
function retarded_time_rhs(τᵣ, p, t)
4368
traj, r_obs = p
@@ -78,7 +103,28 @@ function retarded_time_problem(traj::TrajectoryInterpolant, screen::ObserverScre
78103
return EnsembleProblem(prob; prob_func = set_pixel, safetycopy = false)
79104
end
80105

81-
function accumulate_potential(trajs, screen, K; alg, ensemblealg = nothing, solve_kwargs...)
106+
"""
107+
accumulate_potential(trajs, screen, alg; solve_kwargs...)
108+
accumulate_potential(trajs, screen, alg, ensemblealg; solve_kwargs...)
109+
110+
Compute the Liénard-Wiechert 4-potential on `screen` from electron `trajs`.
111+
112+
For each electron trajectory, solves the retarded-time ODE to map observer time to
113+
proper time, then evaluates `Aμ = K u^μ / (X^R · u)` at uniform observer-time samples.
114+
Returns `A[k, μ, ix, iy]` — the time-domain 4-potential ready for FFT.
115+
116+
The two-argument `alg` form uses a `reinit!`-based integrator pool for efficient CPU
117+
threading. The four-argument form with `ensemblealg` uses `EnsembleProblem` for
118+
compatibility with GPU backends (e.g., `EnsembleGPUKernel`).
119+
120+
# Arguments
121+
- `trajs`: vector of `TrajectoryInterpolant` from [`trajectory_interpolants`](@ref)
122+
- `screen`: `ObserverScreen` defining pixel grid and observer-time samples
123+
- `alg`: ODE solver algorithm for the retarded-time problem (e.g., `Tsit5()`)
124+
- `ensemblealg`: (optional) ensemble algorithm (e.g., `EnsembleGPUKernel(backend)`)
125+
- `solve_kwargs...`: additional keyword arguments passed to the ODE solver
126+
"""
127+
function accumulate_potential(trajs::Vector{<:TrajectoryInterpolant}, screen::ObserverScreen, alg; solve_kwargs...)
82128
x⁰_samples = screen.x⁰_samples
83129
N_samples = length(x⁰_samples)
84130
Nx, Ny = length(screen.x_grid), length(screen.y_grid)
@@ -89,64 +135,73 @@ function accumulate_potential(trajs, screen, K; alg, ensemblealg = nothing, solv
89135
τi = first(traj.itp.t)
90136
τf = last(traj.itp.t)
91137

92-
if ensemblealg !== nothing
93-
# GPU / custom ensemble path
94-
rt_prob = retarded_time_problem(traj, screen)
95-
N_pixels = Nx * Ny
96-
rt_sol = solve(
97-
rt_prob, alg, ensemblealg;
98-
trajectories = N_pixels, saveat = x⁰_samples, solve_kwargs...
99-
)
100-
LI = LinearIndices((Nx, Ny))
101-
102-
Threads.@threads for ix in Base.OneTo(Nx)
103-
for iy in Base.OneTo(Ny)
104-
_accumulate_pixel!(A, traj, screen, K, rt_sol.u[LI[ix, iy]].u, ix, iy)
105-
end
106-
end
107-
else
108-
# CPU path: reinit! to avoid repeated solver initialization
109-
r_obs_0 = SVector{3}(screen.x_grid[1], screen.y_grid[1], screen.z)
110-
x⁰_i_0 = advanced_time(traj, τi, r_obs_0)
111-
x⁰_f_0 = advanced_time(traj, τf, r_obs_0)
112-
proto_prob = ODEProblem{false, SciMLBase.FullSpecialize}(
113-
retarded_time_rhs, τi, (x⁰_i_0, x⁰_f_0), (traj, r_obs_0)
114-
)
115-
116-
nworkers = Threads.nthreads()
117-
integ_pool = Channel{Any}(nworkers)
118-
for _ in 1:nworkers
119-
put!(integ_pool, init(proto_prob, alg; saveat = x⁰_samples, solve_kwargs...))
138+
r_obs_0 = SVector{3}(screen.x_grid[1], screen.y_grid[1], screen.z)
139+
x⁰_i_0 = advanced_time(traj, τi, r_obs_0)
140+
x⁰_f_0 = advanced_time(traj, τf, r_obs_0)
141+
proto_prob = ODEProblem{false, SciMLBase.FullSpecialize}(
142+
retarded_time_rhs, τi, (x⁰_i_0, x⁰_f_0), (traj, r_obs_0)
143+
)
144+
145+
nworkers = Threads.nthreads()
146+
integ_pool = Channel{Any}(nworkers)
147+
for _ in 1:nworkers
148+
put!(integ_pool, init(proto_prob, alg; saveat = x⁰_samples, solve_kwargs...))
149+
end
150+
151+
Threads.@threads for ix in Base.OneTo(Nx)
152+
integ = take!(integ_pool)
153+
for iy in Base.OneTo(Ny)
154+
r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z)
155+
x⁰_i = advanced_time(traj, τi, r_obs)
156+
x⁰_f = advanced_time(traj, τf, r_obs)
157+
158+
integ.p = (traj, r_obs)
159+
reinit!(integ, τi; t0 = x⁰_i, tf = x⁰_f)
160+
solve!(integ)
161+
162+
_accumulate_pixel!(A, traj, screen, integ.sol.u, ix, iy)
120163
end
164+
put!(integ_pool, integ)
165+
end
166+
end
167+
168+
return A
169+
end
121170

122-
Threads.@threads for ix in Base.OneTo(Nx)
123-
integ = take!(integ_pool)
124-
for iy in Base.OneTo(Ny)
125-
r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z)
126-
x⁰_i = advanced_time(traj, τi, r_obs)
127-
x⁰_f = advanced_time(traj, τf, r_obs)
171+
# Ensemble path (for GPU / custom ensemble algorithms)
172+
function accumulate_potential(trajs::Vector{<:TrajectoryInterpolant}, screen::ObserverScreen, alg, ensemblealg; solve_kwargs...)
173+
x⁰_samples = screen.x⁰_samples
174+
N_samples = length(x⁰_samples)
175+
Nx, Ny = length(screen.x_grid), length(screen.y_grid)
128176

129-
integ.p = (traj, r_obs)
130-
reinit!(integ, τi; t0 = x⁰_i, tf = x⁰_f)
131-
solve!(integ)
177+
A = zeros(N_samples, 4, Nx, Ny)
132178

133-
_accumulate_pixel!(A, traj, screen, K, integ.sol.u, ix, iy)
134-
end
135-
put!(integ_pool, integ)
179+
for (j, traj) in enumerate(trajs)
180+
rt_prob = retarded_time_problem(traj, screen)
181+
N_pixels = Nx * Ny
182+
rt_sol = solve(
183+
rt_prob, alg, ensemblealg;
184+
trajectories = N_pixels, saveat = x⁰_samples, solve_kwargs...
185+
)
186+
LI = LinearIndices((Nx, Ny))
187+
188+
Threads.@threads for ix in Base.OneTo(Nx)
189+
for iy in Base.OneTo(Ny)
190+
_accumulate_pixel!(A, traj, screen, rt_sol.u[LI[ix, iy]].u, ix, iy)
136191
end
137192
end
138193
end
139194

140195
return A
141196
end
142197

143-
function _accumulate_pixel!(A, traj, screen, K, τ_samples, ix, iy)
198+
function _accumulate_pixel!(A, traj, screen, τ_samples, ix, iy)
144199
r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z)
145200
for (k, τ) in enumerate(τ_samples)
146201
rμ, uμ = traj(τ)
147202
disp = r_obs - rμ[SA[2, 3, 4]]
148203
xR = SVector{4}(norm(disp), disp[1], disp[2], disp[3])
149-
@views A[k, :, ix, iy] .+= K *./ m_dot(xR, uμ)
204+
@views A[k, :, ix, iy] .+= traj.K *./ m_dot(xR, uμ)
150205
end
151206
return
152207
end

0 commit comments

Comments
 (0)