|
| 1 | +# Trajectory access (thread-safe interpolation wrapper) |
| 2 | +struct TrajectoryInterpolant{I, R, U} |
| 3 | + itp::I # DataInterpolations interpolant (SVector{8} → SVector{8}) |
| 4 | + r_idxs::R # SVector{4, Int} for x⁰, x¹, x², x³ |
| 5 | + u_idxs::U # SVector{4, Int} for u⁰, u¹, u², u³ |
| 6 | +end |
| 7 | + |
| 8 | +function TrajectoryInterpolant(sol::SciMLBase.AbstractODESolution, x_syms, u_syms) |
| 9 | + r_idxs = SVector{4, Int}(variable_index.((sol,), collect(x_syms))) |
| 10 | + u_idxs = SVector{4, Int}(variable_index.((sol,), collect(u_syms))) |
| 11 | + itp = CubicSpline(sol.u, sol.t; extrapolation = ExtrapolationType.Extension) |
| 12 | + return TrajectoryInterpolant(itp, r_idxs, u_idxs) |
| 13 | +end |
| 14 | + |
| 15 | +function (t::TrajectoryInterpolant)(τ) |
| 16 | + v = t.itp(τ) |
| 17 | + rμ = v[t.r_idxs] |
| 18 | + uμ = v[t.u_idxs] |
| 19 | + return (rμ, uμ) |
| 20 | +end |
| 21 | + |
| 22 | +# Convenience: extract all trajectories from ensemble solution |
| 23 | +function trajectories(ensemble_sol::EnsembleSolution) |
| 24 | + sys = ensemble_sol.u[1].prob.f.sys |
| 25 | + return [TrajectoryInterpolant(ensemble_sol.u[i], sys.x, sys.u) for i in axes(ensemble_sol.u, 1)] |
| 26 | +end |
| 27 | + |
| 28 | +# Observer geometry + temporal sampling |
| 29 | +struct ObserverScreen{G, T, R} |
| 30 | + x_grid::G # e.g., LinRange for x |
| 31 | + y_grid::G # e.g., LinRange for y |
| 32 | + z::T # screen distance |
| 33 | + x⁰_samples::R # uniform observer-time sampling grid |
| 34 | +end |
| 35 | + |
| 36 | +# dτᵣ/dt = 1/(u⁰(τᵣ) - u⃗(τᵣ)·n̂(τᵣ, r_obs)) |
| 37 | +function retarded_time_rhs(τᵣ, p, t) |
| 38 | + traj, r_obs = p |
| 39 | + rμ, uμ = traj(τᵣ) |
| 40 | + rⁱ = rμ[SA[2, 3, 4]] |
| 41 | + n̂ = (r_obs - rⁱ) * inv(norm(r_obs - rⁱ)) |
| 42 | + return inv(uμ[1] - uμ[SA[2, 3, 4]] ⋅ n̂) |
| 43 | +end |
| 44 | + |
| 45 | +function advanced_time(traj, τ, x_obs) |
| 46 | + rμ, _ = traj(τ) |
| 47 | + return rμ[1] + norm(x_obs - rμ[SA[2, 3, 4]]) |
| 48 | +end |
| 49 | + |
| 50 | +function retarded_time_problem(traj::TrajectoryInterpolant, screen::ObserverScreen) |
| 51 | + Nx, Ny = length(screen.x_grid), length(screen.y_grid) |
| 52 | + CI = CartesianIndices((Nx, Ny)) |
| 53 | + |
| 54 | + r_obs_0 = SVector{3}(screen.x_grid[1], screen.y_grid[1], screen.z) |
| 55 | + τi = first(traj.itp.t) |
| 56 | + τf = last(traj.itp.t) |
| 57 | + (x⁰_i_0, x⁰_f_0) = advanced_time(traj, τi, r_obs_0), advanced_time(traj, τf, r_obs_0) |
| 58 | + |
| 59 | + prob = ODEProblem{false, SciMLBase.FullSpecialize}( |
| 60 | + retarded_time_rhs, |
| 61 | + τi, |
| 62 | + (x⁰_i_0, x⁰_f_0), |
| 63 | + (traj, r_obs_0), |
| 64 | + ) |
| 65 | + |
| 66 | + function set_pixel(prob, i, repeat) |
| 67 | + ix, iy = Tuple(CI[i]) |
| 68 | + r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z) |
| 69 | + (x⁰_i, x⁰_f) = advanced_time(traj, τi, r_obs), advanced_time(traj, τf, r_obs) |
| 70 | + return remake(prob; p = (traj, r_obs), tspan = (x⁰_i, x⁰_f)) |
| 71 | + end |
| 72 | + |
| 73 | + return EnsembleProblem(prob; prob_func = set_pixel, safetycopy = false) |
| 74 | +end |
| 75 | + |
| 76 | +function accumulate_potential(trajs, screen, K; alg, ensemblealg = nothing, solve_kwargs...) |
| 77 | + x⁰_samples = screen.x⁰_samples |
| 78 | + N_samples = length(x⁰_samples) |
| 79 | + Nx, Ny = length(screen.x_grid), length(screen.y_grid) |
| 80 | + |
| 81 | + A = zeros(N_samples, 4, Nx, Ny) |
| 82 | + |
| 83 | + for (j, traj) in enumerate(trajs) |
| 84 | + τi = first(traj.itp.t) |
| 85 | + τf = last(traj.itp.t) |
| 86 | + |
| 87 | + if ensemblealg !== nothing |
| 88 | + # GPU / custom ensemble path |
| 89 | + rt_prob = retarded_time_problem(traj, screen) |
| 90 | + N_pixels = Nx * Ny |
| 91 | + rt_sol = solve( |
| 92 | + rt_prob, alg, ensemblealg; |
| 93 | + trajectories = N_pixels, saveat = x⁰_samples, solve_kwargs... |
| 94 | + ) |
| 95 | + LI = LinearIndices((Nx, Ny)) |
| 96 | + |
| 97 | + Threads.@threads for ix in Base.OneTo(Nx) |
| 98 | + for iy in Base.OneTo(Ny) |
| 99 | + _accumulate_pixel!(A, traj, screen, K, rt_sol.u[LI[ix, iy]].u, ix, iy) |
| 100 | + end |
| 101 | + end |
| 102 | + else |
| 103 | + # CPU path: reinit! to avoid repeated solver initialization |
| 104 | + r_obs_0 = SVector{3}(screen.x_grid[1], screen.y_grid[1], screen.z) |
| 105 | + x⁰_i_0 = advanced_time(traj, τi, r_obs_0) |
| 106 | + x⁰_f_0 = advanced_time(traj, τf, r_obs_0) |
| 107 | + proto_prob = ODEProblem{false, SciMLBase.FullSpecialize}( |
| 108 | + retarded_time_rhs, τi, (x⁰_i_0, x⁰_f_0), (traj, r_obs_0) |
| 109 | + ) |
| 110 | + |
| 111 | + nworkers = Threads.nthreads() |
| 112 | + integ_pool = Channel{Any}(nworkers) |
| 113 | + for _ in 1:nworkers |
| 114 | + put!(integ_pool, init(proto_prob, alg; saveat = x⁰_samples, solve_kwargs...)) |
| 115 | + end |
| 116 | + |
| 117 | + Threads.@threads for ix in Base.OneTo(Nx) |
| 118 | + integ = take!(integ_pool) |
| 119 | + for iy in Base.OneTo(Ny) |
| 120 | + r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z) |
| 121 | + x⁰_i = advanced_time(traj, τi, r_obs) |
| 122 | + x⁰_f = advanced_time(traj, τf, r_obs) |
| 123 | + |
| 124 | + integ.p = (traj, r_obs) |
| 125 | + reinit!(integ, τi; t0 = x⁰_i, tf = x⁰_f) |
| 126 | + solve!(integ) |
| 127 | + |
| 128 | + _accumulate_pixel!(A, traj, screen, K, integ.sol.u, ix, iy) |
| 129 | + end |
| 130 | + put!(integ_pool, integ) |
| 131 | + end |
| 132 | + end |
| 133 | + end |
| 134 | + |
| 135 | + return A |
| 136 | +end |
| 137 | + |
| 138 | +function _accumulate_pixel!(A, traj, screen, K, τ_samples, ix, iy) |
| 139 | + r_obs = SVector{3}(screen.x_grid[ix], screen.y_grid[iy], screen.z) |
| 140 | + for (k, τ) in enumerate(τ_samples) |
| 141 | + rμ, uμ = traj(τ) |
| 142 | + disp = r_obs - rμ[SA[2, 3, 4]] |
| 143 | + xR = SVector{4}(norm(disp), disp[1], disp[2], disp[3]) |
| 144 | + @views A[k, :, ix, iy] .+= K * uμ ./ m_dot(xR, uμ) |
| 145 | + end |
| 146 | + return |
| 147 | +end |
0 commit comments