|
| 1 | +# Symbolic PINN Parser MVP |
| 2 | + |
| 3 | +function _build_domain_grids(domains; n_points=20) |
| 4 | + grids = Dict{Any, Vector{Float64}}() |
| 5 | + for dom in domains |
| 6 | + iv = dom.variables |
| 7 | + interval = dom.domain |
| 8 | + lo = DomainSets.leftendpoint(interval) |
| 9 | + hi = DomainSets.rightendpoint(interval) |
| 10 | + grids[iv] = collect(range(Float64(lo), Float64(hi), length=n_points)) |
| 11 | + end |
| 12 | + return grids |
| 13 | +end |
| 14 | + |
| 15 | +function _replace_dv_calls(expr, dvs, ivs, nn_out_sym) |
| 16 | + dv_op_names = [string(Symbolics.operation(Symbolics.unwrap(dv))) for dv in dvs] |
| 17 | + expr_unwrapped = Symbolics.unwrap(expr) |
| 18 | + |
| 19 | + pw = SymbolicUtils.Postwalk(x -> begin |
| 20 | + if SymbolicUtils.istree(x) |
| 21 | + op = Symbolics.operation(x) |
| 22 | + op_name = string(op) |
| 23 | + idx = findfirst(==(op_name), dv_op_names) |
| 24 | + if idx !== nothing |
| 25 | + args = collect(Symbolics.arguments(x)) |
| 26 | + coord_map = Dict{Any, Any}() |
| 27 | + for i in 1:min(length(ivs), length(args)) |
| 28 | + coord_map[ivs[i]] = args[i] |
| 29 | + end |
| 30 | + nn_at_args = substitute(nn_out_sym[idx], coord_map) |
| 31 | + return Symbolics.unwrap(nn_at_args) |
| 32 | + end |
| 33 | + end |
| 34 | + return x |
| 35 | + end) |
| 36 | + |
| 37 | + return Num(pw(expr_unwrapped)) |
| 38 | +end |
| 39 | + |
| 40 | +function _build_compact_symbolic_loss(eqs, bcs, dvs, ivs; n_points=20, bc_weight=10.0) |
| 41 | + iv_list = join(string.(ivs), ", ") |
| 42 | + |
| 43 | + pde_terms = String[] |
| 44 | + for eq in eqs |
| 45 | + res_str = string(eq.lhs - eq.rhs) |
| 46 | + for (j, dv) in enumerate(dvs) |
| 47 | + dv_op = split(string(dv), "(")[1] |
| 48 | + call_pat = Regex("\\b" * dv_op * "\\([^\\)]*\\)") |
| 49 | + res_str = replace(res_str, call_pat => "NN_$(j)($(iv_list))") |
| 50 | + end |
| 51 | + push!(pde_terms, "(" * res_str * ")^2") |
| 52 | + end |
| 53 | + |
| 54 | + bc_terms = String[] |
| 55 | + for bc in bcs |
| 56 | + res_bc_str = string(bc.lhs - bc.rhs) |
| 57 | + for (j, dv) in enumerate(dvs) |
| 58 | + dv_op = split(string(dv), "(")[1] |
| 59 | + call_pat = Regex("\\b" * dv_op * "\\([^\\)]*\\)") |
| 60 | + res_bc_str = replace(res_bc_str, call_pat => "NN_$(j)($(iv_list))") |
| 61 | + end |
| 62 | + push!(bc_terms, "(" * res_bc_str * ")^2") |
| 63 | + end |
| 64 | + |
| 65 | + io = IOBuffer() |
| 66 | + println(io, "# Compact symbolic PINN loss template") |
| 67 | + println(io, "# NN_j(...) are registered neural outputs (not expanded layer-by-layer).") |
| 68 | + println(io) |
| 69 | + |
| 70 | + println(io, "PDE residual terms:") |
| 71 | + if isempty(pde_terms) |
| 72 | + println(io, "0") |
| 73 | + else |
| 74 | + println(io, join(pde_terms, " + ")) |
| 75 | + end |
| 76 | + println(io) |
| 77 | + |
| 78 | + println(io, "Boundary residual terms:") |
| 79 | + if isempty(bc_terms) |
| 80 | + println(io, "0") |
| 81 | + else |
| 82 | + println(io, join(bc_terms, " + ")) |
| 83 | + end |
| 84 | + println(io) |
| 85 | + |
| 86 | + println(io, "Loss template:") |
| 87 | + println( |
| 88 | + io, |
| 89 | + "sum_over_" * string(n_points) * "_point_grid(PDE residual terms) + " * |
| 90 | + string(bc_weight) * " * (Boundary residual terms)" |
| 91 | + ) |
| 92 | + |
| 93 | + return String(take!(io)) |
| 94 | +end |
| 95 | + |
| 96 | +function build_pinn_loss( |
| 97 | + pde_system, |
| 98 | + chain = nothing; |
| 99 | + width=16, |
| 100 | + depth=2, |
| 101 | + activation=tanh, |
| 102 | + n_points=20, |
| 103 | + bc_weight=10.0, |
| 104 | + show_symbolic_expression=true, |
| 105 | + symbolic_expression_path="pinn_symbolic_expression.txt", |
| 106 | + symbolic_expression_style=:compact, |
| 107 | + rng=Random.default_rng() |
| 108 | +) |
| 109 | + eqs = collect(pde_system.eqs) |
| 110 | + bcs = collect(pde_system.bcs) |
| 111 | + dvs = collect(pde_system.dvs) |
| 112 | + ivs = collect(pde_system.ivs) |
| 113 | + doms = collect(pde_system.domain) |
| 114 | + |
| 115 | + if chain === nothing |
| 116 | + # We manually construct a fully connected network as fallback |
| 117 | + # If ModelingToolkitNeuralNets is not available we could build manually via Lux |
| 118 | + chain = Lux.Chain(Lux.Dense(length(ivs), width, activation), Lux.Dense(width, length(dvs))) |
| 119 | + end |
| 120 | + |
| 121 | + p_init, st = Lux.setup(rng, chain) |
| 122 | + p_ca = ComponentArray(p_init) |
| 123 | + @variables p_sym[1:length(p_ca)] |
| 124 | + p_sym_ca = ComponentArray(p_sym, getaxes(p_ca)) |
| 125 | + |
| 126 | + nn_out_sym, _ = Lux.apply(chain, ivs, p_sym_ca, st) |
| 127 | + |
| 128 | + grids = _build_domain_grids(doms; n_points=n_points) |
| 129 | + |
| 130 | + # 1. Compile PDE residuals into functions |
| 131 | + compiled_res_funcs = [] |
| 132 | + for eq in eqs |
| 133 | + res = eq.lhs - eq.rhs |
| 134 | + res = _replace_dv_calls(res, dvs, ivs, nn_out_sym) |
| 135 | + res = expand_derivatives(res) |
| 136 | + |
| 137 | + # Compile to a function taking (p, ivs...) |
| 138 | + bf = Symbolics.build_function(res, p_sym, ivs..., expression=Val(false)) |
| 139 | + push!(compiled_res_funcs, bf isa Tuple ? bf[1] : bf) |
| 140 | + end |
| 141 | + |
| 142 | + # 2. Compile Boundary Condition residuals |
| 143 | + compiled_bc_funcs = [] |
| 144 | + for bc in bcs |
| 145 | + res_bc = bc.lhs - bc.rhs |
| 146 | + res_bc = _replace_dv_calls(res_bc, dvs, ivs, nn_out_sym) |
| 147 | + res_bc = expand_derivatives(res_bc) |
| 148 | + |
| 149 | + # Evaluate boundaries across the grid |
| 150 | + bf_bc = Symbolics.build_function(res_bc, p_sym, ivs..., expression=Val(false)) |
| 151 | + push!(compiled_bc_funcs, bf_bc isa Tuple ? bf_bc[1] : bf_bc) |
| 152 | + end |
| 153 | + |
| 154 | + # Grid arrays |
| 155 | + iv_vecs = [grids[iv] for iv in ivs] |
| 156 | + |
| 157 | + if show_symbolic_expression |
| 158 | + output_path = isabspath(symbolic_expression_path) ? symbolic_expression_path : joinpath(pwd(), symbolic_expression_path) |
| 159 | + |
| 160 | + expr_text = if symbolic_expression_style == :expanded |
| 161 | + "Expanded symbolic expression tracking full grid is disabled in Lazy Grid Sum mode." |
| 162 | + elseif symbolic_expression_style == :compact |
| 163 | + _build_compact_symbolic_loss(eqs, bcs, dvs, ivs; n_points=n_points, bc_weight=bc_weight) |
| 164 | + else |
| 165 | + error("symbolic_expression_style must be :compact or :expanded") |
| 166 | + end |
| 167 | + |
| 168 | + open(output_path, "w") do io |
| 169 | + print(io, expr_text) |
| 170 | + end |
| 171 | + println("Symbolic PINN loss expression (" * String(symbolic_expression_style) * ") saved to: " * output_path) |
| 172 | + end |
| 173 | + |
| 174 | + loss_func = (p) -> begin |
| 175 | + loss = zero(eltype(p)) |
| 176 | + |
| 177 | + # PDE Residuals over grid |
| 178 | + for cpt in Iterators.product(iv_vecs...) |
| 179 | + for rf in compiled_res_funcs |
| 180 | + r = rf(p, cpt...) |
| 181 | + loss += r^2 |
| 182 | + end |
| 183 | + end |
| 184 | + |
| 185 | + # Boundary Conditions over the grid |
| 186 | + for cpt in Iterators.product(iv_vecs...) |
| 187 | + for bcf in compiled_bc_funcs |
| 188 | + r_bc = bcf(p, cpt...) |
| 189 | + loss += bc_weight * r_bc^2 |
| 190 | + end |
| 191 | + end |
| 192 | + |
| 193 | + return loss |
| 194 | + end |
| 195 | + |
| 196 | + return loss_func, p_ca, chain, st |
| 197 | +end |
0 commit comments