Skip to content

Commit 9eb811b

Browse files
committed
address comments
1 parent a070be2 commit 9eb811b

5 files changed

Lines changed: 172 additions & 233 deletions

File tree

src/bridges.jl

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -115,75 +115,85 @@ _square_offset(::MOI.AbstractSymmetricMatrixSetSquare) = 0
115115
_square_offset(::MOI.RootDetConeSquare) = 1
116116
_square_offset(::MOI.LogDetConeSquare) = 2
117117

118-
function _square_to_triangle_indices(
119-
bridge::MOI.Bridges.Constraint.SquareBridge,
120-
)
121-
s = bridge.square_set
122-
dim = MOI.side_dimension(s)
123-
offset = _square_offset(s)
124-
upper_triangle_indices = collect(1:offset)
125-
sizehint!(upper_triangle_indices, offset + div(dim * (dim + 1), 2))
126-
k = offset
127-
for j in 1:dim
128-
for i in 1:j
129-
k += 1
130-
push!(upper_triangle_indices, k)
131-
end
132-
k += dim - j
133-
end
134-
return upper_triangle_indices
135-
end
136-
137-
"""
138-
_triangle_to_square_scalars(tri_scalars, s)
139-
140-
Expand triangle-vectorized scalars to square column-major form, mirroring
141-
off-diagonal entries. `s` is the square set (e.g. `PositiveSemidefiniteConeSquare`).
142-
"""
143-
function _triangle_to_square_scalars(tri_scalars, s)
144-
dim = MOI.side_dimension(s)
145-
offset = _square_offset(s)
146-
square_dim = offset + dim * dim
147-
square = Vector{eltype(tri_scalars)}(undef, square_dim)
148-
for i in 1:offset
149-
square[i] = tri_scalars[i]
150-
end
151-
tri_k = offset
152-
for j in 1:dim
153-
for i in 1:j
154-
tri_k += 1
155-
ij = offset + i + (j - 1) * dim
156-
square[ij] = tri_scalars[tri_k]
157-
if i != j
158-
ji = offset + j + (i - 1) * dim
159-
square[ji] = tri_scalars[tri_k]
160-
end
161-
end
162-
end
163-
return square
164-
end
165-
118+
# Similar to `MOI.set` for `MOI.ConstraintPrimalStart` on `SquareBridge` in
119+
# MathOptInterface/src/Bridges/Constraint/bridges/SquareBridge.jl
166120
function MOI.set(
167121
model::MOI.ModelLike,
168122
attr::DiffOpt.ForwardConstraintFunction,
169123
bridge::MOI.Bridges.Constraint.SquareBridge{T},
170124
func::MOI.VectorAffineFunction{T},
171125
) where {T}
172-
indices = _square_to_triangle_indices(bridge)
173-
tri_func = MOI.Utilities.eachscalar(func)[indices]
174-
return MOI.set(model, attr, bridge.triangle, tri_func)
126+
dim = MOI.side_dimension(bridge.square_set)
127+
offset = _square_offset(bridge.square_set)
128+
scalars = MOI.Utilities.eachscalar(func)
129+
tri_scalars =
130+
Vector{eltype(scalars)}(undef, offset + div(dim * (dim + 1), 2))
131+
for i in 1:offset
132+
tri_scalars[i] = scalars[i]
133+
end
134+
k = offset
135+
for j in 1:dim, i in 1:j
136+
k += 1
137+
tri_scalars[k] = scalars[offset+j+(i-1)*dim]
138+
end
139+
MOI.set(
140+
model,
141+
attr,
142+
bridge.triangle,
143+
MOI.Utilities.operate(vcat, T, tri_scalars...),
144+
)
145+
for ((i, j), ci) in bridge.sym
146+
f_ij = scalars[offset+i+(j-1)*dim]
147+
f_ji = scalars[offset+j+(i-1)*dim]
148+
MOI.set(model, attr, ci, MOI.Utilities.operate(-, T, f_ij, f_ji))
149+
end
150+
return
175151
end
176152

153+
# Adjoint of `MOI.set` for `ForwardConstraintFunction` on `SquareBridge` above.
154+
# The forward map extracts upper triangle and sym diffs; this is its transpose.
155+
# Similar structure to `MOI.get` for `MOI.ConstraintPrimal` on `SquareBridge` in
156+
# MathOptInterface/src/Bridges/Constraint/bridges/SquareBridge.jl
177157
function MOI.get(
178158
model::MOI.ModelLike,
179159
attr::DiffOpt.ReverseConstraintFunction,
180160
bridge::MOI.Bridges.Constraint.SquareBridge{T},
181161
) where {T}
182-
tri_func_raw = MOI.get(model, attr, bridge.triangle)
183-
tri_func = DiffOpt.standard_form(tri_func_raw)
184-
tri_scalars = MOI.Utilities.eachscalar(tri_func)
185-
square_scalars = _triangle_to_square_scalars(tri_scalars, bridge.square_set)
186-
return MOI.Utilities.operate(vcat, T, square_scalars...)
162+
tri_func = DiffOpt.standard_form(MOI.get(model, attr, bridge.triangle))
163+
tri = MOI.Utilities.eachscalar(tri_func)
164+
dim = MOI.side_dimension(bridge.square_set)
165+
offset = _square_offset(bridge.square_set)
166+
square = Vector{eltype(tri)}(undef, offset + dim^2)
167+
for i in 1:offset
168+
square[i] = tri[i]
169+
end
170+
k = offset
171+
sym_index = 1
172+
for j in 1:dim, i in 1:j
173+
k += 1
174+
upper_index = offset + i + (j - 1) * dim
175+
lower_index = offset + j + (i - 1) * dim
176+
if i == j
177+
square[upper_index] = tri[k]
178+
elseif sym_index <= length(bridge.sym) &&
179+
bridge.sym[sym_index].first == (i, j)
180+
π = DiffOpt.standard_form(
181+
MOI.get(model, attr, bridge.sym[sym_index].second),
182+
)
183+
square[upper_index] = MOI.Utilities.operate(
184+
+,
185+
T,
186+
MOI.Utilities.operate(+, T, tri[k], tri[k]),
187+
π,
188+
)
189+
square[lower_index] = MOI.Utilities.operate(-, T, π)
190+
sym_index += 1
191+
else
192+
square[upper_index] = tri[k]
193+
square[lower_index] = tri[k]
194+
end
195+
end
196+
return MOI.Utilities.operate(vcat, T, square...)
187197
end
188198

189199
function _variable_to_index_map(bridge)

src/jump_wrapper.jl

Lines changed: 12 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -234,99 +234,41 @@ function set_forward_constraint_function(
234234
},
235235
value::Number,
236236
) where {M}
237-
JuMP.check_belongs_to_model(con_ref, model)
238-
return MOI.set(
239-
JuMP.backend(model),
240-
ForwardConstraintFunction(),
241-
JuMP.index(con_ref),
242-
JuMP.moi_function(JuMP.AffExpr(value)),
243-
)
237+
return set_forward_constraint_function(model, con_ref, JuMP.AffExpr(value))
244238
end
245239

240+
# Similar to `JuMP.set_start_value` for vector `ConstraintRef` in
241+
# JuMP/src/constraints.jl
246242
function set_forward_constraint_function(
247243
model::JuMP.Model,
248244
con_ref::JuMP.ConstraintRef{
249-
M,
245+
<:JuMP.AbstractModel,
250246
<:MOI.ConstraintIndex{<:MOI.AbstractVectorFunction},
251247
},
252248
value::AbstractArray{<:JuMP.AbstractJuMPScalar},
253-
) where {M}
249+
)
254250
JuMP.check_belongs_to_model(con_ref, model)
255251
JuMP.check_belongs_to_model.(value, model)
252+
v = JuMP.vectorize(value, con_ref.shape)
256253
return MOI.set(
257254
JuMP.backend(model),
258255
ForwardConstraintFunction(),
259256
JuMP.index(con_ref),
260-
JuMP.moi_function(value),
261-
)
262-
end
263-
264-
function set_forward_constraint_function(
265-
model::JuMP.Model,
266-
con_ref::JuMP.ConstraintRef{
267-
M,
268-
<:MOI.ConstraintIndex{<:MOI.AbstractVectorFunction},
269-
},
270-
value::AbstractArray{<:Number},
271-
) where {M}
272-
JuMP.check_belongs_to_model(con_ref, model)
273-
return MOI.set(
274-
JuMP.backend(model),
275-
ForwardConstraintFunction(),
276-
JuMP.index(con_ref),
277-
JuMP.moi_function(JuMP.AffExpr.(value)),
257+
JuMP.moi_function(v),
278258
)
279259
end
280260

261+
# Similar to `JuMP.set_start_value` for vector `ConstraintRef` in
262+
# JuMP/src/constraints.jl
281263
function set_forward_constraint_function(
282264
model::JuMP.Model,
283265
con_ref::JuMP.ConstraintRef{
284266
<:JuMP.AbstractModel,
285267
<:MOI.ConstraintIndex{<:MOI.AbstractVectorFunction},
286-
S,
287268
},
288-
value::AbstractMatrix{<:Number},
289-
) where {S<:Union{JuMP.SquareMatrixShape,JuMP.SymmetricMatrixShape}}
290-
if !LinearAlgebra.issymmetric(value)
291-
error(
292-
"ForwardConstraintFunction perturbation matrix must be " *
293-
"symmetric for PSD cone constraints.",
294-
)
295-
end
296-
JuMP.check_belongs_to_model(con_ref, model)
297-
v = JuMP.vectorize(value, con_ref.shape)
298-
func = JuMP.moi_function(JuMP.AffExpr.(v))
299-
MOI.set(
300-
JuMP.backend(model),
301-
ForwardConstraintFunction(),
302-
JuMP.index(con_ref),
303-
func,
304-
)
305-
return
306-
end
307-
308-
function set_forward_constraint_function(
309-
model::JuMP.Model,
310-
con_ref::JuMP.ConstraintRef{<:JuMP.AbstractModel,<:MOI.ConstraintIndex,S},
311-
value::AbstractMatrix{<:JuMP.AbstractJuMPScalar},
312-
) where {S<:Union{JuMP.SquareMatrixShape,JuMP.SymmetricMatrixShape}}
313-
if !LinearAlgebra.issymmetric(value)
314-
error(
315-
"ForwardConstraintFunction perturbation matrix must be " *
316-
"symmetric for PSD cone constraints.",
317-
)
318-
end
319-
JuMP.check_belongs_to_model(con_ref, model)
320-
JuMP.check_belongs_to_model.(value, model)
321-
v = JuMP.vectorize(value, con_ref.shape)
322-
func = JuMP.moi_function(v)
323-
MOI.set(
324-
JuMP.backend(model),
325-
ForwardConstraintFunction(),
326-
JuMP.index(con_ref),
327-
func,
328-
)
329-
return
269+
value::AbstractArray{<:Number},
270+
)
271+
return set_forward_constraint_function(model, con_ref, JuMP.AffExpr.(value))
330272
end
331273

332274
"""

test/bridges.jl

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -104,77 +104,6 @@ function test_dU_from_dQ()
104104
return _test_dU_dQ(U, dU)
105105
end
106106

107-
function _make_square_bridge(dim)
108-
s = MOI.PositiveSemidefiniteConeSquare(dim)
109-
return MOI.Bridges.Constraint.SquareBridge{
110-
Float64,
111-
MOI.VectorAffineFunction{Float64},
112-
MOI.ScalarAffineFunction{Float64},
113-
MOI.PositiveSemidefiniteConeTriangle,
114-
MOI.PositiveSemidefiniteConeSquare,
115-
}(
116-
s,
117-
MOI.ConstraintIndex{
118-
MOI.VectorAffineFunction{Float64},
119-
MOI.PositiveSemidefiniteConeTriangle,
120-
}(
121-
1,
122-
),
123-
Pair{
124-
Tuple{Int,Int},
125-
MOI.ConstraintIndex{
126-
MOI.ScalarAffineFunction{Float64},
127-
MOI.EqualTo{Float64},
128-
},
129-
}[],
130-
)
131-
end
132-
133-
function test_square_to_triangle_indices()
134-
# 2x2: square col-major [a11, a21, a12, a22] → upper tri [a11, a12, a22]
135-
@test DiffOpt._square_to_triangle_indices(_make_square_bridge(2)) ==
136-
[1, 3, 4]
137-
# 3x3: square col-major [a11,a21,a31, a12,a22,a32, a13,a23,a33]
138-
# → upper tri [a11, a12,a22, a13,a23,a33] at indices [1, 4,5, 7,8,9]
139-
@test DiffOpt._square_to_triangle_indices(_make_square_bridge(3)) ==
140-
[1, 4, 5, 7, 8, 9]
141-
# 1x1: trivial
142-
@test DiffOpt._square_to_triangle_indices(_make_square_bridge(1)) == [1]
143-
end
144-
145-
function _make_rootdet_square_bridge(dim)
146-
s = MOI.RootDetConeSquare(dim)
147-
return MOI.Bridges.Constraint.SquareBridge{
148-
Float64,
149-
MOI.VectorAffineFunction{Float64},
150-
MOI.ScalarAffineFunction{Float64},
151-
MOI.RootDetConeTriangle,
152-
MOI.RootDetConeSquare,
153-
}(
154-
s,
155-
MOI.ConstraintIndex{
156-
MOI.VectorAffineFunction{Float64},
157-
MOI.RootDetConeTriangle,
158-
}(
159-
1,
160-
),
161-
Pair{
162-
Tuple{Int,Int},
163-
MOI.ConstraintIndex{
164-
MOI.ScalarAffineFunction{Float64},
165-
MOI.EqualTo{Float64},
166-
},
167-
}[],
168-
)
169-
end
170-
171-
function test_square_to_triangle_indices_with_offset()
172-
# RootDetConeSquare(2) has 1 offset entry (the rootdet variable t)
173-
# Full square: [t, a11, a21, a12, a22] → triangle: [t, a11, a12, a22]
174-
bridge = _make_rootdet_square_bridge(2)
175-
@test DiffOpt._square_to_triangle_indices(bridge) == [1, 2, 4, 5]
176-
end
177-
178107
function test_square_offset()
179108
@test DiffOpt._square_offset(MOI.PositiveSemidefiniteConeSquare(2)) == 0
180109
@test DiffOpt._square_offset(MOI.RootDetConeSquare(2)) == 1
@@ -187,27 +116,6 @@ function test_square_offset()
187116
[1, 2]
188117
end
189118

190-
function test_triangle_to_square_scalars()
191-
# 2x2 PSD: triangle [a11, a12, a22] → square [a11, a12, a12, a22]
192-
s = MOI.PositiveSemidefiniteConeSquare(2)
193-
@test DiffOpt._triangle_to_square_scalars([1, 2, 3], s) == [1, 2, 2, 3]
194-
# 3x3 PSD: triangle [a11, a12, a22, a13, a23, a33]
195-
# → square col-major [a11, a12, a13, a12, a22, a23, a13, a23, a33]
196-
s3 = MOI.PositiveSemidefiniteConeSquare(3)
197-
@test DiffOpt._triangle_to_square_scalars([1, 2, 3, 4, 5, 6], s3) ==
198-
[1, 2, 4, 2, 3, 5, 4, 5, 6]
199-
# RootDetConeSquare(2): triangle [t, a11, a12, a22]
200-
# → square [t, a11, a12, a12, a22]
201-
sr = MOI.RootDetConeSquare(2)
202-
@test DiffOpt._triangle_to_square_scalars([10, 1, 2, 3], sr) ==
203-
[10, 1, 2, 2, 3]
204-
# LogDetConeSquare(2): triangle [u, t, a11, a12, a22]
205-
# → square [u, t, a11, a12, a12, a22]
206-
sl = MOI.LogDetConeSquare(2)
207-
@test DiffOpt._triangle_to_square_scalars([10, 20, 1, 2, 3], sl) ==
208-
[10, 20, 1, 2, 2, 3]
209-
end
210-
211119
end
212120

213121
TestBridges.runtests()

0 commit comments

Comments
 (0)