Skip to content

Commit 715d0a9

Browse files
committed
Experimental Symbolic PINN Parser Module
1 parent 7790c41 commit 715d0a9

4 files changed

Lines changed: 285 additions & 0 deletions

File tree

demo_symbolic_expression.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Compact symbolic PINN loss template
2+
# NN_j(...) are registered neural outputs (not expanded layer-by-layer).
3+
4+
PDE residual terms:
5+
(Differential(t, 1)(NN_1(t, x)) + Differential(x, 1)(NN_1(t, x)))^2
6+
7+
Boundary residual terms:
8+
(NN_1(t, x) - sin(6.283185307179586x))^2 + (NN_1(t, x) - sin(-6.283185307179586t))^2 + (-sin(6.283185307179586(1.0 - t)) + NN_1(t, x))^2
9+
10+
Loss template:
11+
sum_over_10_point_grid(PDE residual terms) + 15.0 * (Boundary residual terms)

demo_symbolic_parser.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import Pkg
2+
Pkg.activate(temp=true)
3+
Pkg.develop(path=".")
4+
Pkg.add(["ModelingToolkit", "Lux", "DomainSets", "Optimization", "OptimizationOptimisers", "ComponentArrays", "Plots", "ModelingToolkitNeuralNets"])
5+
6+
using NeuralPDE
7+
using ModelingToolkit
8+
using DomainSets
9+
using ComponentArrays
10+
using Optimization
11+
using OptimizationOptimisers
12+
using Lux
13+
using Plots
14+
using Random
15+
16+
Random.seed!(42)
17+
18+
println("==== Symbolic PINN Parser Demo (MVP) ====")
19+
20+
# Advection Equation MVP problem:
21+
# ∂u/∂t + 1.0 * ∂u/∂x = 0
22+
# u(0, x) = sin(2pi * x)
23+
# True solution: u(t, x) = sin(2pi * (x - t))
24+
25+
@parameters t x
26+
@variables u(..)
27+
Dt = Differential(t)
28+
Dx = Differential(x)
29+
30+
eq = Dt(u(t, x)) + 1.0 * Dx(u(t, x)) ~ 0
31+
32+
bcs = [
33+
u(0.0, x) ~ sin(2pi * x),
34+
u(t, 0.0) ~ sin(-2pi * t),
35+
u(t, 1.0) ~ sin(2pi * (1.0 - t))
36+
]
37+
38+
domains = [
39+
t Interval(0.0, 1.0),
40+
x Interval(0.0, 1.0)
41+
]
42+
43+
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
44+
45+
# Use the newly ported MVP parser
46+
loss_func, p_init, chain, st = build_pinn_loss(
47+
pde_system;
48+
width=6,
49+
depth=1,
50+
activation=tanh,
51+
n_points=10,
52+
bc_weight=15.0,
53+
symbolic_expression_path="demo_symbolic_expression.txt"
54+
)
55+
56+
# Convert to Optimization problem
57+
obj = OptimizationFunction((theta, _) -> loss_func(theta), Optimization.AutoForwardDiff())
58+
prob = OptimizationProblem(obj, collect(p_init), nothing)
59+
60+
println("Starting training for Advection Equation via Symbolic Loss Function...")
61+
res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.01), maxiters=1500)
62+
println("Training complete. Final objective = ", res.objective)
63+
64+
# Evaluate and print comparison
65+
theta_final = ComponentArray(res.u, getaxes(p_init))
66+
67+
xs = collect(range(0.0, 1.0, length=10))
68+
u_pred = [first(Lux.apply(chain, [0.5, xi], theta_final, st)[1]) for xi in xs]
69+
u_true = [sin(2pi * (xi - 0.5)) for xi in xs]
70+
71+
println("\n=== Results ===")
72+
println("Predictions: ", round.(u_pred, digits=4))
73+
println("Exact: ", round.(u_true, digits=4))
74+
println("Max error: ", round(maximum(abs.(u_pred .- u_true)), digits=4))
75+
println("\nSymbolic PINN parser executed successfully!\n")

src/NeuralPDE.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,15 @@ include("PDE_BPINN.jl")
9292
include("dgm.jl")
9393
include("NN_SDE_solve.jl")
9494
include("NN_SDE_weaksolve.jl")
95+
include("symbolic_pinn_parser.jl")
9596

9697
export PINOODE
9798
export NNODE, NNDAE
9899
export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde
99100
export NNSDE
100101
export SDEPINN
101102
export PhysicsInformedNN, discretize
103+
export build_pinn_loss
102104
export BPINNsolution, BayesianPINN
103105
export DeepGalerkin
104106

src/symbolic_pinn_parser.jl

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Symbolic PINN Parser MVP
2+
3+
function _build_domain_grids(domains; n_points=20)
4+
grids = Dict{Any, Vector{Float64}}()
5+
for dom in domains
6+
iv = dom.variables
7+
interval = dom.domain
8+
lo = DomainSets.leftendpoint(interval)
9+
hi = DomainSets.rightendpoint(interval)
10+
grids[iv] = collect(range(Float64(lo), Float64(hi), length=n_points))
11+
end
12+
return grids
13+
end
14+
15+
function _replace_dv_calls(expr, dvs, ivs, nn_out_sym)
16+
dv_op_names = [string(Symbolics.operation(Symbolics.unwrap(dv))) for dv in dvs]
17+
expr_unwrapped = Symbolics.unwrap(expr)
18+
19+
pw = SymbolicUtils.Postwalk(x -> begin
20+
if SymbolicUtils.istree(x)
21+
op = Symbolics.operation(x)
22+
op_name = string(op)
23+
idx = findfirst(==(op_name), dv_op_names)
24+
if idx !== nothing
25+
args = collect(Symbolics.arguments(x))
26+
coord_map = Dict{Any, Any}()
27+
for i in 1:min(length(ivs), length(args))
28+
coord_map[ivs[i]] = args[i]
29+
end
30+
nn_at_args = substitute(nn_out_sym[idx], coord_map)
31+
return Symbolics.unwrap(nn_at_args)
32+
end
33+
end
34+
return x
35+
end)
36+
37+
return Num(pw(expr_unwrapped))
38+
end
39+
40+
function _build_compact_symbolic_loss(eqs, bcs, dvs, ivs; n_points=20, bc_weight=10.0)
41+
iv_list = join(string.(ivs), ", ")
42+
43+
pde_terms = String[]
44+
for eq in eqs
45+
res_str = string(eq.lhs - eq.rhs)
46+
for (j, dv) in enumerate(dvs)
47+
dv_op = split(string(dv), "(")[1]
48+
call_pat = Regex("\\b" * dv_op * "\\([^\\)]*\\)")
49+
res_str = replace(res_str, call_pat => "NN_$(j)($(iv_list))")
50+
end
51+
push!(pde_terms, "(" * res_str * ")^2")
52+
end
53+
54+
bc_terms = String[]
55+
for bc in bcs
56+
res_bc_str = string(bc.lhs - bc.rhs)
57+
for (j, dv) in enumerate(dvs)
58+
dv_op = split(string(dv), "(")[1]
59+
call_pat = Regex("\\b" * dv_op * "\\([^\\)]*\\)")
60+
res_bc_str = replace(res_bc_str, call_pat => "NN_$(j)($(iv_list))")
61+
end
62+
push!(bc_terms, "(" * res_bc_str * ")^2")
63+
end
64+
65+
io = IOBuffer()
66+
println(io, "# Compact symbolic PINN loss template")
67+
println(io, "# NN_j(...) are registered neural outputs (not expanded layer-by-layer).")
68+
println(io)
69+
70+
println(io, "PDE residual terms:")
71+
if isempty(pde_terms)
72+
println(io, "0")
73+
else
74+
println(io, join(pde_terms, " + "))
75+
end
76+
println(io)
77+
78+
println(io, "Boundary residual terms:")
79+
if isempty(bc_terms)
80+
println(io, "0")
81+
else
82+
println(io, join(bc_terms, " + "))
83+
end
84+
println(io)
85+
86+
println(io, "Loss template:")
87+
println(
88+
io,
89+
"sum_over_" * string(n_points) * "_point_grid(PDE residual terms) + " *
90+
string(bc_weight) * " * (Boundary residual terms)"
91+
)
92+
93+
return String(take!(io))
94+
end
95+
96+
function build_pinn_loss(
97+
pde_system,
98+
chain = nothing;
99+
width=16,
100+
depth=2,
101+
activation=tanh,
102+
n_points=20,
103+
bc_weight=10.0,
104+
show_symbolic_expression=true,
105+
symbolic_expression_path="pinn_symbolic_expression.txt",
106+
symbolic_expression_style=:compact,
107+
rng=Random.default_rng()
108+
)
109+
eqs = collect(pde_system.eqs)
110+
bcs = collect(pde_system.bcs)
111+
dvs = collect(pde_system.dvs)
112+
ivs = collect(pde_system.ivs)
113+
doms = collect(pde_system.domain)
114+
115+
if chain === nothing
116+
# We manually construct a fully connected network as fallback
117+
# If ModelingToolkitNeuralNets is not available we could build manually via Lux
118+
chain = Lux.Chain(Lux.Dense(length(ivs), width, activation), Lux.Dense(width, length(dvs)))
119+
end
120+
121+
p_init, st = Lux.setup(rng, chain)
122+
p_ca = ComponentArray(p_init)
123+
@variables p_sym[1:length(p_ca)]
124+
p_sym_ca = ComponentArray(p_sym, getaxes(p_ca))
125+
126+
nn_out_sym, _ = Lux.apply(chain, ivs, p_sym_ca, st)
127+
128+
grids = _build_domain_grids(doms; n_points=n_points)
129+
130+
# 1. Compile PDE residuals into functions
131+
compiled_res_funcs = []
132+
for eq in eqs
133+
res = eq.lhs - eq.rhs
134+
res = _replace_dv_calls(res, dvs, ivs, nn_out_sym)
135+
res = expand_derivatives(res)
136+
137+
# Compile to a function taking (p, ivs...)
138+
bf = Symbolics.build_function(res, p_sym, ivs..., expression=Val(false))
139+
push!(compiled_res_funcs, bf isa Tuple ? bf[1] : bf)
140+
end
141+
142+
# 2. Compile Boundary Condition residuals
143+
compiled_bc_funcs = []
144+
for bc in bcs
145+
res_bc = bc.lhs - bc.rhs
146+
res_bc = _replace_dv_calls(res_bc, dvs, ivs, nn_out_sym)
147+
res_bc = expand_derivatives(res_bc)
148+
149+
# Evaluate boundaries across the grid
150+
bf_bc = Symbolics.build_function(res_bc, p_sym, ivs..., expression=Val(false))
151+
push!(compiled_bc_funcs, bf_bc isa Tuple ? bf_bc[1] : bf_bc)
152+
end
153+
154+
# Grid arrays
155+
iv_vecs = [grids[iv] for iv in ivs]
156+
157+
if show_symbolic_expression
158+
output_path = isabspath(symbolic_expression_path) ? symbolic_expression_path : joinpath(pwd(), symbolic_expression_path)
159+
160+
expr_text = if symbolic_expression_style == :expanded
161+
"Expanded symbolic expression tracking full grid is disabled in Lazy Grid Sum mode."
162+
elseif symbolic_expression_style == :compact
163+
_build_compact_symbolic_loss(eqs, bcs, dvs, ivs; n_points=n_points, bc_weight=bc_weight)
164+
else
165+
error("symbolic_expression_style must be :compact or :expanded")
166+
end
167+
168+
open(output_path, "w") do io
169+
print(io, expr_text)
170+
end
171+
println("Symbolic PINN loss expression (" * String(symbolic_expression_style) * ") saved to: " * output_path)
172+
end
173+
174+
loss_func = (p) -> begin
175+
loss = zero(eltype(p))
176+
177+
# PDE Residuals over grid
178+
for cpt in Iterators.product(iv_vecs...)
179+
for rf in compiled_res_funcs
180+
r = rf(p, cpt...)
181+
loss += r^2
182+
end
183+
end
184+
185+
# Boundary Conditions over the grid
186+
for cpt in Iterators.product(iv_vecs...)
187+
for bcf in compiled_bc_funcs
188+
r_bc = bcf(p, cpt...)
189+
loss += bc_weight * r_bc^2
190+
end
191+
end
192+
193+
return loss
194+
end
195+
196+
return loss_func, p_ca, chain, st
197+
end

0 commit comments

Comments
 (0)