Skip to content

Commit 5df24d1

Browse files
authored
Cache subexpressions when building MOI.ScalarNonlinearExpression (#4032)
1 parent 1033c70 commit 5df24d1

7 files changed

Lines changed: 390 additions & 9 deletions

File tree

.vale.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ BasedOnStyles = Vale, Google
1515

1616
# TODO(odow): fix all of these
1717
Google.Ellipses = OFF
18+
Google.EmDash = OF
1819
Google.Exclamation = OFF
1920
Google.FirstPerson = OFF
2021
Google.OptionalPlurals = OFF

src/JuMP.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ mutable struct GenericModel{T<:Real} <: AbstractModel
162162
# A dictionary to store timing information from the JuMP macros.
163163
enable_macro_timing::Bool
164164
macro_times::Dict{Tuple{LineNumberNode,String},Float64}
165+
# A cache to track common subexpressions based on their `objectid`.
166+
subexpressions::Dict{UInt64,MOI.ScalarNonlinearFunction}
165167
end
166168

167169
value_type(::Type{GenericModel{T}}) where {T} = T
@@ -276,6 +278,7 @@ function direct_generic_model(
276278
Dict{Any,MOI.ConstraintIndex}(),
277279
false,
278280
Dict{Tuple{LineNumberNode,String},Float64}(),
281+
Dict{UInt64,MOI.ScalarNonlinearFunction}(),
279282
)
280283
end
281284

src/constraints.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,10 @@ function moi_function(constraint::AbstractConstraint)
769769
return moi_function(jump_function(constraint))
770770
end
771771

772+
function moi_function(model, constraint::AbstractConstraint)
773+
return moi_function(model, jump_function(constraint))
774+
end
775+
772776
"""
773777
moi_set(constraint::AbstractConstraint)
774778
@@ -1028,6 +1032,18 @@ function _moi_add_constraint(
10281032
return MOI.add_constraint(model, f, s)
10291033
end
10301034

1035+
function check_belongs_to_model(f::Vector, model)
1036+
for func in f
1037+
check_belongs_to_model(func, model)
1038+
end
1039+
return
1040+
end
1041+
1042+
function moi_function(model, f)
1043+
check_belongs_to_model(f, model)
1044+
return moi_function(f)
1045+
end
1046+
10311047
"""
10321048
add_constraint(
10331049
model::GenericModel,
@@ -1044,10 +1060,9 @@ function add_constraint(
10441060
name::String = "",
10451061
)
10461062
con = model_convert(model, con)
1063+
func, set = moi_function(model, con), moi_set(con)
10471064
# The type of backend(model) is unknown so we directly redirect to another
10481065
# function.
1049-
check_belongs_to_model(con, model)
1050-
func, set = moi_function(con), moi_set(con)
10511066
cindex = _moi_add_constraint(
10521067
backend(model),
10531068
func,

src/nlp_expr.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -596,18 +596,29 @@ moi_function(x::Number) = x
596596
# `moi_function(AbstractArray{<:AbstractVariableRef})`
597597
moi_function(x::AbstractArray) = moi_function.(x)
598598

599-
function moi_function(f::GenericNonlinearExpr{V}) where {V}
599+
function moi_function(model::GenericModel, f::GenericNonlinearExpr{V}) where {V}
600+
key = objectid(f)
601+
if haskey(model.subexpressions, key)
602+
return model.subexpressions[key]
603+
end
600604
ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
601605
stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
602606
for i in length(f.args):-1:1
603607
if f.args[i] isa GenericNonlinearExpr{V}
604608
push!(stack, (ret, i, f.args[i]))
609+
elseif f.args[i] isa AbstractJuMPScalar
610+
ret.args[i] = moi_function(model, f.args[i])
605611
else
606612
ret.args[i] = moi_function(f.args[i])
607613
end
608614
end
609615
while !isempty(stack)
610616
parent, i, arg = pop!(stack)
617+
arg_key = objectid(arg)
618+
if haskey(model.subexpressions, arg_key)
619+
parent.args[i] = model.subexpressions[arg_key]
620+
continue
621+
end
611622
child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
612623
parent.args[i] = child
613624
for j in length(arg.args):-1:1
@@ -617,7 +628,9 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
617628
child.args[j] = moi_function(arg.args[j])
618629
end
619630
end
631+
model.subexpressions[arg_key] = child
620632
end
633+
model.subexpressions[key] = ret
621634
return ret
622635
end
623636

@@ -1239,7 +1252,8 @@ function moi_function(f::AbstractVector{<:GenericNonlinearExpr})
12391252
end
12401253

12411254
function MOI.VectorNonlinearFunction(f::Vector{<:AbstractJuMPScalar})
1242-
return MOI.VectorNonlinearFunction(map(moi_function, f))
1255+
model = owner_model(first(f))
1256+
return MOI.VectorNonlinearFunction(moi_function.(model, f))
12431257
end
12441258

12451259
"""
@@ -1285,7 +1299,7 @@ x
12851299
```
12861300
"""
12871301
function simplify(model::GenericModel, f::AbstractJuMPScalar)
1288-
g = MOI.Nonlinear.SymbolicAD.simplify(moi_function(f))
1302+
g = MOI.Nonlinear.SymbolicAD.simplify(moi_function(model, f))
12891303
return jump_function(model, g)
12901304
end
12911305

@@ -1336,7 +1350,8 @@ function derivative(
13361350
f::AbstractJuMPScalar,
13371351
x::GenericVariableRef{T},
13381352
) where {T}
1339-
df_dx = MOI.Nonlinear.SymbolicAD.derivative(moi_function(f), index(x))
1353+
df_dx =
1354+
MOI.Nonlinear.SymbolicAD.derivative(moi_function(model, f), index(x))
13401355
return jump_function(model, MOI.Nonlinear.SymbolicAD.simplify!(df_dx))
13411356
end
13421357

@@ -1381,7 +1396,7 @@ julia> ∇f[y]
13811396
```
13821397
"""
13831398
function gradient(model::GenericModel{T}, f::AbstractJuMPScalar) where {T}
1384-
g = moi_function(f)
1399+
g = moi_function(model, f)
13851400
∇f = Dict{GenericVariableRef{T},Any}()
13861401
for xi in MOI.Nonlinear.SymbolicAD.variables(g)
13871402
df_dx = MOI.Nonlinear.SymbolicAD.simplify!(

src/objective.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ end
280280

281281
function set_objective_function(model::GenericModel, func::AbstractJuMPScalar)
282282
check_belongs_to_model(func, model)
283-
set_objective_function(model, moi_function(func))
283+
set_objective_function(model, moi_function(model, func))
284284
return
285285
end
286286

@@ -299,7 +299,7 @@ function set_objective_function(
299299
for f in func
300300
check_belongs_to_model(f, model)
301301
end
302-
set_objective_function(model, moi_function(func))
302+
set_objective_function(model, moi_function(model, func))
303303
return
304304
end
305305

0 commit comments

Comments
 (0)