Skip to content

Commit 0829602

Browse files
feat: [AI] allow marking systems as incompatible with symbolic AD
This is useful for `mtkcompile`/other passes which rephrase the system in a way that is incompatible with symbolic AD. Inline linear SCCs is one such pass. Co-authored by: Claude<noreply@anthropic.com>
1 parent 7148b9e commit 0829602

5 files changed

Lines changed: 36 additions & 0 deletions

File tree

lib/ModelingToolkitBase/src/ModelingToolkitBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ const set_scalar_metadata = setmetadata
353353
@public check_mutable_cache, store_to_mutable_cache!, should_invalidate_mutable_cache_entry
354354
@public convert_bindings_for_time_independent_system, get_w
355355
@public Both
356+
@public SymbolicADDisallowed, check_symbolic_ad_allowed
356357

357358
for prop in [SYS_PROPS; [:continuous_events, :discrete_events]]
358359
getter = Symbol(:get_, prop)

lib/ModelingToolkitBase/src/systems/codegen.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ Calculate the gradient of the equations of `sys` with respect to the independent
170170
`simplify` is forwarded to `Symbolics.expand_derivatives`.
171171
"""
172172
function calculate_tgrad(sys::System; simplify = false)
173+
check_symbolic_ad_allowed(sys)
173174
# We need to remove explicit time dependence on the unknown because when we
174175
# have `u(t) * t` we want to have the tgrad to be `u(t)` instead of `u'(t) *
175176
# t + u(t)`.
@@ -198,6 +199,7 @@ function calculate_jacobian(
198199
sys::System;
199200
sparse = false, simplify = false, dvs = unknowns(sys)
200201
)
202+
check_symbolic_ad_allowed(sys)
201203
eqs = full_equations(sys)
202204
rhs = SymbolicT[]
203205
sizehint!(rhs, length(eqs))

lib/ModelingToolkitBase/src/systems/system.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,37 @@ Utility metadata key for adding miscellaneous/one-off metadata to systems.
2323
"""
2424
abstract type MiscSystemData end
2525

26+
"""
27+
$TYPEDEF
28+
29+
Metadata key used to mark a system as incompatible with symbolic automatic differentiation.
30+
When set on a system via `setmetadata(sys, SymbolicADDisallowed, reason)`, any attempt to
31+
perform symbolic AD on the equations of that system (e.g. via `calculate_jacobian`,
32+
`calculate_tgrad`, `linearize_symbolic`, or during structural simplification) will throw
33+
an error. The value associated with this key should be a descriptive `String` explaining
34+
why symbolic AD is unsupported, or `true` if no explanation is available.
35+
36+
See also: [`check_symbolic_ad_allowed`](@ref).
37+
"""
38+
abstract type SymbolicADDisallowed end
39+
40+
"""
41+
check_symbolic_ad_allowed(sys::AbstractSystem)
42+
43+
Check whether `sys` supports symbolic automatic differentiation. Throws an `ArgumentError`
44+
if the system has been marked with [`SymbolicADDisallowed`](@ref).
45+
"""
46+
function check_symbolic_ad_allowed(sys::AbstractSystem)
47+
if SymbolicUtils.hasmetadata(sys, SymbolicADDisallowed)
48+
reason = SymbolicUtils.getmetadata(sys, SymbolicADDisallowed, nothing)
49+
msg = "System $(nameof(sys)) does not support symbolic automatic differentiation."
50+
if reason isa AbstractString && !isempty(reason)
51+
msg *= " $reason"
52+
end
53+
throw(ArgumentError(msg))
54+
end
55+
end
56+
2657
"""
2758
$(TYPEDEF)
2859

src/linearization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ function linearize_symbolic(
562562
kwargs...
563563
)
564564
sys = mtkcompile(sys; inputs, outputs, simplify, split, kwargs...)
565+
check_symbolic_ad_allowed(sys)
565566
diff_idxs, alge_idxs = eq_idxs(sys)
566567
sts = unknowns(sys)
567568
t = get_iv(sys)

src/problems/semilinearodeproblem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929

3030
if jac
31+
check_symbolic_ad_allowed(sys)
3132
Cjac = (C === nothing || !stiff_nonlinear) ? nothing : Symbolics.jacobian(C, dvs)
3233
_jac = generate_semiquadratic_jacobian(
3334
sys, A, B, C, Cjac; sparse, expression,

0 commit comments

Comments
 (0)