Skip to content

Commit 4969dc1

Browse files
authored
Merge pull request #9 from SebastianM-C/smc/fix
Fix bug in `GaussLaser` and simplify interface
2 parents 4d6c720 + 86960be commit 4969dc1

15 files changed

Lines changed: 687 additions & 324 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
*.jl.*.cov
22
*.jl.cov
33
*.jl.mem
4-
/Manifest.toml
4+
/Manifest*.toml
55
/docs/Manifest.toml
66
/docs/build/

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ authors = ["Sebastian Micluța-Câmpeanu <sebastian.mc95@proton.me> and contribu
77
HypergeometricFunctions = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
10+
ModelingToolkitBase = "7771a370-6774-4173-bd38-47e70ca0b839"
1011
PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab"
1112
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -23,6 +24,7 @@ JET = "0.9, 0.10, 0.11"
2324
LaserTypes = "0.2"
2425
LinearAlgebra = "1.11"
2526
ModelingToolkit = "11"
27+
ModelingToolkitBase = "1.20.0"
2628
OrdinaryDiffEqNonlinearSolve = "1.10.0"
2729
OrdinaryDiffEqRosenbrock = "1.11.0"
2830
OrdinaryDiffEqVerner = "1.2.0"

scripts/benchmark_field_eval.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@ m_val = 1
1616

1717
# --- FieldEvaluator setup ---
1818
@named ref_frame = ProperFrame(:atomic)
19-
@named laser = LaguerreGaussLaser(
20-
wavelength=λ_val, a0=a₀_val, beam_waist=w₀_val,
21-
radial_index=p_val, azimuthal_index=m_val,
22-
ref_frame=ref_frame, temporal_profile=:constant)
19+
@named laser = LaguerreGaussLaser(;
20+
wavelength = λ_val, a0 = a₀_val, beam_waist = w₀_val,
21+
radial_index = p_val, azimuthal_index = m_val,
22+
ref_frame, temporal_profile = :constant
23+
)
2324

24-
fe = FieldEvaluator(laser, ref_frame)
25+
fe = FieldEvaluator(laser)
2526

2627
# --- LaserTypes setup ---
27-
lt_laser = LaserTypes.LaguerreGaussLaser(:atomic;
28-
λ=λ_val, a₀=a₀_val, w₀=w₀_val, p=p_val, m=m_val)
28+
lt_laser = LaserTypes.LaguerreGaussLaser(
29+
:atomic;
30+
λ = λ_val, a₀ = a₀_val, w₀ = w₀_val, p = p_val, m = m_val
31+
)
2932

3033
# Test point
3134
x, y, z = 0.3w₀_val, 0.1w₀_val, 0.5w₀_val
@@ -60,8 +63,10 @@ display(@benchmark LaserTypes.B($pos, $t_val, $lt_laser))
6063
println()
6164

6265
println("=== LaserTypes E + B ===")
63-
display(@benchmark begin
64-
LaserTypes.E($pos, $t_val, $lt_laser)
65-
LaserTypes.B($pos, $t_val, $lt_laser)
66-
end)
66+
display(
67+
@benchmark begin
68+
LaserTypes.E($pos, $t_val, $lt_laser)
69+
LaserTypes.B($pos, $t_val, $lt_laser)
70+
end
71+
)
6772
println()

scripts/ensemble.jl

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using GLMakie
99
using LaTeXStrings
1010

1111
const inch = 96
12-
const pt = 4/3
12+
const pt = 4 / 3
1313
const cm = inch / 2.54
1414

1515
# set_theme!(
@@ -22,21 +22,21 @@ const cm = inch / 2.54
2222
# Code is using Atomic Units !!!
2323
# natural constants
2424
c = 137.03599908330932 # speed of light
25-
qme = -1. # specific charge
25+
qme = -1.0 # specific charge
2626

2727
h = 2π
28-
α = 1/c
29-
ε₀=qme^2/(2α*h*c)
30-
μ₀=1/(ε₀*c^2)
28+
α = 1 / c
29+
ε₀ = qme^2 / (2α * h * c)
30+
μ₀ = 1 / (ε₀ * c^2)
3131

3232
# derived
3333
ω = 0.057
34-
τ = 150/ω
35-
λ = 2π*c/ω
34+
τ = 150 / ω
35+
λ = 2π * c / ω
3636
w₀ = 75λ
3737
Rmax = 3.25w₀
3838

39-
ξx, ξy = (1/2, im/2) .|> complex
39+
ξx, ξy = (1 / 2, im / 2) .|> complex
4040

4141
# Laser parameters in atomic units
4242
λ_au = λ
@@ -54,20 +54,20 @@ z₀ = 0.0
5454
# @named ref_frame = LabFrame(:atomic)
5555

5656
@named laser = LaguerreGaussLaser(
57-
wavelength=λ_au,
58-
a0=a₀,
59-
beam_waist=w₀_au,
60-
radial_index=p_index,
61-
azimuthal_index=m_index,
62-
ref_frame=ref_frame,
63-
temporal_profile=:gaussian, # Using Gaussian profile
64-
temporal_width=τ_fwhm,
65-
focus_position=z₀,
66-
polarization=:circular
57+
wavelength = λ_au,
58+
a0 = a₀,
59+
beam_waist = w₀_au,
60+
radial_index = p_index,
61+
azimuthal_index = m_index,
62+
ref_frame,
63+
temporal_profile = :gaussian, # Using Gaussian profile
64+
temporal_width = τ_fwhm,
65+
focus_position = z₀,
66+
polarization = :circular
6767
)
6868

6969
# Create electron system
70-
@named lg_elec = ClassicalElectron(; laser, ref_frame)
70+
@named lg_elec = ClassicalElectron(; laser)
7171

7272

7373
# Compile the system
@@ -79,21 +79,21 @@ sys = mtkcompile(lg_elec)
7979
tspan = (τi, τf)
8080

8181
# Create base problem with placeholder initial position
82-
x⁰ = [τi*c, 0.0, 0.0, 0.0]
82+
x⁰ = [τi * c, 0.0, 0.0, 0.0]
8383
u⁰ = [c, 0.0, 0.0, 0.0]
8484

8585
u0 = [
8686
(sys.x) => x⁰,
87-
(sys.u) => u⁰
87+
(sys.u) => u⁰,
8888
]
8989

90-
prob = ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0, tspan, u0_constructor=SVector{8}, fully_determined=true)
91-
sol0 = solve(prob, Vern9(), reltol = 1e-15, abstol = 1e-12)
90+
prob = ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0, tspan, u0_constructor = SVector{8}, fully_determined = true)
91+
sol0 = solve(prob, Vern9(), reltol = 1.0e-15, abstol = 1.0e-12)
9292

9393
# Sunflower pattern for initial positions
9494
N = 900
9595

96-
const ϕ = (1 + 5)/2
96+
const ϕ = (1 + 5) / 2
9797

9898
function radius(k, n, b)
9999
if k > n - b
@@ -105,21 +105,21 @@ end
105105

106106
function sunflower(n, α)
107107
points = []
108-
angle_stride = 2π/ϕ^2 # geodesic ? 360 * ϕ :
108+
angle_stride = 2π / ϕ^2 # geodesic ? 360 * ϕ :
109109
b = round(Int, α * sqrt(n)) # number of boundary points
110110

111111
for k in 1:n
112112
r = radius(k, n, b)
113113
θ = k * angle_stride
114-
append!(points, ([r * cos(θ), r * sin(θ)], ))
114+
append!(points, ([r * cos(θ), r * sin(θ)],))
115115
end
116116

117117
return points
118118
end
119119

120120
# Generate initial positions in sunflower pattern
121-
R₀ = Rmax*sunflower(N, 2)
122-
= [[τi*c, r..., 0.] for r in R₀]
121+
R₀ = Rmax * sunflower(N, 2)
122+
= [[τi * c, r..., 0.0] for r in R₀]
123123

124124
# Use SymbolicIndexingInterface to set positions
125125
set_x = setsym_oop(prob, [Initial(sys.x); Initial(sys.u)]);
@@ -136,59 +136,63 @@ function prob_func(prob, i, repeat)
136136
# Set new initial conditions
137137
u0, p = set_x(prob, SVector{8}(x_new..., u_new...))
138138

139-
remake(prob; u0, p)
139+
return remake(prob; u0, p)
140140
end
141141

142142
# Absolute error tolerance function
143143
function abserr(a₀)
144144
amp = log10(a₀)
145-
expo = -amp^2/27 + 32amp/27 - 220/27
146-
10^expo
145+
expo = -amp^2 / 27 + 32amp / 27 - 220 / 27
146+
return 10^expo
147147
end
148148

149149
# Create ensemble problem
150-
ensemble = EnsembleProblem(prob; prob_func, safetycopy=false)
150+
ensemble = EnsembleProblem(prob; prob_func, safetycopy = false)
151151

152152
# Solve ensemble
153-
solution = solve(ensemble, Vern9(), EnsembleThreads();
154-
reltol=1e-12, abstol=abserr(a₀),
155-
trajectories=N)
153+
solution = solve(
154+
ensemble, Vern9(), EnsembleThreads();
155+
reltol = 1.0e-12, abstol = abserr(a₀),
156+
trajectories = N
157+
)
156158

157159
# ru = solution.u[1](range(τi, τf, 1001), idxs = [sys.x; sys.u])
158160

159161
# Solve single trajectory for visualization (electron #1)
160162
x_single = SVector{8}(xμ[1]..., u⁰...)
161163
u0_single, p_single = set_x(prob, x_single)
162-
prob_single = remake(prob; u0=u0_single, p=p_single)
164+
prob_single = remake(prob; u0 = u0_single, p = p_single)
163165

164-
sol = solve(prob_single, Vern9(),
165-
reltol=1e-12,
166-
abstol=1e-20)
166+
sol = solve(
167+
prob_single, Vern9(),
168+
reltol = 1.0e-12,
169+
abstol = 1.0e-20
170+
)
167171

168172
#### eval field
169173

170174
_t = 0
171175
_x = sol[sys.x, 500]
172176

173-
x_sub = map(x->EvalAt(_t)(x[1])=>x[2], collect(sys.x .=> _x))
174-
eval_point = [laser.τ=>0; x_sub; sys.t => EvalAt(_t)(sys.x[1]) / c]
177+
x_sub = map(x -> EvalAt(_t)(x[1]) => x[2], collect(sys.x .=> _x))
178+
eval_point = [laser.τ => 0; x_sub; sys.t => EvalAt(_t)(sys.x[1]) / c]
175179

176180
all_eqs = Symbolics.fixpoint_sub(equations(laser), merge(initial_conditions(laser), initial_conditions(eval_point)))
177-
eq_dict = Dict(map(eq->eq.lhs=>eq.rhs, all_eqs[setdiff(1:19, 10:15)]))
181+
eq_dict = Dict(map(eq -> eq.lhs => eq.rhs, all_eqs[setdiff(1:19, 10:15)]))
178182
Symbolics.fixpoint_sub(all_eqs, eq_dict)
179183

180184
# using CairoMakie
181185
# Visualization
182-
fig = Figure(fontsize=14pt)
186+
fig = Figure(fontsize = 14pt)
183187
# ax = Axis3(fig[1, 1], aspect=:data)
184-
ax = Axis3(fig[1, 1], aspect=(1, 1, 1))
188+
ax = Axis3(fig[1, 1], aspect = (1, 1, 1))
185189

186190

187191
# Extract trajectory
188-
t_range = range(τi, τf, length=10001)
189-
x_traj = [sol(t, idxs=sys.x[2]) for t in t_range]
190-
y_traj = [sol(t, idxs=sys.x[3]) for t in t_range]
191-
z_traj = [sol(t, idxs=sys.x[4]) for t in t_range]
192+
t_range = range(τi, τf, length = 10001)
193+
x_traj = [sol(t, idxs = sys.x[2]) for t in t_range]
194+
y_traj = [sol(t, idxs = sys.x[3]) for t in t_range]
195+
z_traj = [sol(t, idxs = sys.x[4]) for t in t_range]
192196

193197
lines!(ax, x_traj, y_traj, z_traj)
194198

src/ElectronDynamicsModels.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module ElectronDynamicsModels
22

33
using ModelingToolkit
4-
using ModelingToolkit: SymbolicT, build_explicit_observed_function
5-
using SymbolicIndexingInterface: setsym_oop
4+
using ModelingToolkitBase: AbstractSystem, SymbolicT, build_explicit_observed_function, get_systems
5+
using SymbolicIndexingInterface: getname, setsym_oop
66
using PhysicalConstants, Unitful, UnitfulAtomic
77
using PhysicalConstants.CODATA2018: c_0, e, m_e, ε_0
88
using LinearAlgebra

src/dynamics.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@component function ParticleDynamics(; name, mass, ref_frame)
2+
@unpack c = ref_frame
23
iv = ModelingToolkit.get_iv(ref_frame)
34
D = Differential(iv)
45

@@ -12,7 +13,6 @@
1213

1314
if nameof(iv) ==
1415
τ = iv
15-
c = ref_frame.c
1616

1717
@variables begin
1818
t(iv), [description = "Universal time"]
@@ -36,7 +36,6 @@
3636
System(eqs, iv, [t, γ, x, u, p, F_total], [m]; name, systems = [ref_frame])
3737
elseif nameof(iv) == :t
3838
t = iv
39-
c = ref_frame.c
4039

4140
@variables begin
4241
τ(t), [description = "Proper time"]

0 commit comments

Comments
 (0)