Skip to content

Commit 659c9ff

Browse files
committed
First draft of BatchExaModel
1 parent 72bacec commit 659c9ff

2 files changed

Lines changed: 276 additions & 1 deletion

File tree

src/ExaModels.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ include("hessian.jl")
3737
include("nlp.jl")
3838
include("utils.jl")
3939
include("two_stage.jl")
40+
include("batch.jl")
4041

4142
export ExaModel,
4243
ExaCore,
@@ -58,6 +59,7 @@ export ExaModel,
5859
multipliers_U,
5960
@register_univariate,
6061
@register_bivariate,
61-
TwoStageExaModel
62+
TwoStageExaModel,
63+
BatchExaModel
6264

6365
end # module ExaModels

src/batch.jl

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
BatchExaModel{T, VT, M}
3+
4+
Parametric optimization model where multiple scenarios are fused into a single
5+
ExaModel and evaluated simultaneously using a shared compiled expression pattern.
6+
7+
All scenarios must share identical sparsity structures for Jacobians and Hessians
8+
independently of the parameter values. The model builder receives all variables
9+
and parameters and should rely on generators iterating over scenario data.
10+
11+
# Dimensions
12+
13+
- `ns`: number of scenarios
14+
- `nv`: number of variables per scenario
15+
- `nc`: number of constraints per scenario
16+
- `np`: number of parameters per scenario
17+
18+
# Layout
19+
20+
- Variables: [v₁; v₂; …; vₙₛ]
21+
- Constraints: [c₁; c₂; …; cₙₛ]
22+
23+
# Fields
24+
25+
- `model::M` : fused ExaModel containing all scenarios
26+
- `ns::Int` : number of scenarios
27+
- `np::Int` : number of parameters per scenario
28+
"""
29+
struct BatchExaModel{T, VT <: AbstractVector{T}, M <: ExaModel{T, VT}} <: NLPModels.AbstractNLPModel{T,VT}
30+
model::M
31+
ns::Int
32+
np::Int
33+
end
34+
35+
function Base.show(io::IO, m::BatchExaModel{T, VT}) where {T, VT}
36+
println(io, "BatchExaModel{$T, $VT}")
37+
println(io, " Number of scenarios: $(m.ns)")
38+
println(io, " Number of parameter per scenario: $(m.np)")
39+
Base.show(m.model)
40+
return
41+
end
42+
43+
# ============================================================================
44+
# Constructor
45+
# ============================================================================
46+
47+
"""
48+
BatchExaModel(build, nd, nv, ns, θ_sets; backend=nothing)
49+
50+
Build a batch model where all scenarios are fused into a single ExaModel.
51+
52+
All scenarios share ONE compiled expression pattern, achieving maximum GPU efficiency.
53+
This requires scenarios to have identical structure.
54+
55+
# Arguments
56+
- `build::Function`: Function `(c, d, v, θ, ns, nv, nθ) -> nothing`
57+
- `c`: ExaCore
58+
- `d`: Variable handle for design variables (indices 1:nd)
59+
- `v`: Variable handle for ALL recourse variables (indices 1:ns*nv)
60+
Scenario i's vars are at indices (i-1)*nv+1 : i*nv
61+
- `θ`: Parameter handle for ALL parameters (length ns*nθ)
62+
Scenario i's params are at indices (i-1)*nθ+1 : i*nθ
63+
- `ns, nv, nθ`: dimensions for building iteration data
64+
- `nd::Int`: Number of design variables
65+
- `nv::Int`: Number of recourse variables per scenario
66+
- `ns::Int`: Number of scenarios
67+
- `θ_sets::Vector{<:AbstractVector}`: Parameter vectors for each scenario
68+
69+
# Keyword Arguments
70+
- `backend`: Backend for computation (default: `nothing`)
71+
- `d_start`: Initial values for design variables (scalar or vector of length `nd`, default: `0.0`)
72+
- `d_lvar`: Lower bounds for design variables (scalar or vector of length `nd`, default: `-Inf`)
73+
- `d_uvar`: Upper bounds for design variables (scalar or vector of length `nd`, default: `Inf`)
74+
- `v_start`: Initial values for recourse variables (scalar or vector of length `ns*nv`, default: `0.0`)
75+
- `v_lvar`: Lower bounds for recourse variables (scalar or vector of length `ns*nv`, default: `-Inf`)
76+
- `v_uvar`: Upper bounds for recourse variables (scalar or vector of length `ns*nv`, default: `Inf`)
77+
78+
# Example
79+
```julia
80+
ns, nv, nd, nθ = 100, 5, 2, 3
81+
θ_sets = [rand(nθ) for _ in 1:ns]
82+
83+
model = BatchExaModel(nd, nv, ns, θ_sets) do c, d, v, θ, ns, nv, nθ
84+
obj_data = [(i, j, (i-1)*nv + j, (i-1)*nθ) for i in 1:ns for j in 1:nv]
85+
objective(c, θ[θ_off + 1] * v[v_idx]^2 for (i, j, v_idx, θ_off) in obj_data)
86+
87+
con_data = [(i, j, (i-1)*nv + j, (i-1)*nθ) for i in 1:ns for j in 1:nv]
88+
constraint(c, v[v_idx] + d[1] - θ[θ_off + 3] for (i, j, v_idx, θ_off) in con_data)
89+
end
90+
```
91+
"""
92+
function BatchExaModel(
93+
build::Function,
94+
nd::Int,
95+
nv::Int,
96+
ns::Int,
97+
θ_sets::Vector{<:AbstractVector};
98+
backend = nothing,
99+
d_start = 0.0,
100+
d_lvar = -Inf,
101+
d_uvar = Inf,
102+
v_start = 0.0,
103+
v_lvar = -Inf,
104+
v_uvar = Inf
105+
)
106+
length(θ_sets) == ns || throw(ArgumentError("θ_sets must have length ns=$ns"))
107+
= length(θ_sets[1])
108+
all(length(θ) ==for θ in θ_sets) || throw(ArgumentError("All θ_sets must have same length"))
109+
110+
c = ExaCore(; backend = backend)
111+
112+
# All recourse vars as one block, all params as one vector
113+
v = variable(c, ns * nv; start = v_start, lvar = v_lvar, uvar = v_uvar)
114+
d = variable(c, nd; start = d_start, lvar = d_lvar, uvar = d_uvar)
115+
θ_flat = reduce(vcat, θ_sets)
116+
θ = parameter(c, θ_flat)
117+
118+
nc_before = c.ncon
119+
build(c, d, v, θ, ns, nv, nθ)
120+
121+
model = ExaModel(c)
122+
123+
# Calculate nnz per scenario
124+
total_nnzj = NLPModels.get_nnzj(model)
125+
total_nnzh = NLPModels.get_nnzh(model)
126+
nnzj = total_nnzj ÷ ns
127+
nnzh= total_nnzh ÷ ns
128+
129+
T = eltype(c.x0)
130+
VT = typeof(c.x0)
131+
132+
return BatchExaModel{T, VT, typeof(model)}(
133+
model, ns, nv, nd, nc, nθ, nnzj_per_scenario, nnzh_per_scenario
134+
)
135+
end
136+
137+
# ============================================================================
138+
# Full Evaluation (Single Kernel Launch)
139+
# ============================================================================
140+
141+
"""
142+
obj(model::BatchExaModel, x_global)
143+
144+
Evaluate all objectives.
145+
Output: obj_global ∈ ℝ^{ns}
146+
"""
147+
function obj(model::BatchExaModel, x_global::AbstractVector)
148+
return obj(model.model, x_global, obj_global)
149+
end
150+
151+
"""
152+
cons!(model::BatchExaModel, x_global, c_global)
153+
154+
Evaluate all constraints.
155+
Output: c_global ∈ ℝ^{ns*nc}
156+
"""
157+
function cons!(
158+
model::BatchExaModel,
159+
x_global::AbstractVector,
160+
c_global::AbstractVector
161+
)
162+
cons!(model.model, x_global, c_global)
163+
return c_global
164+
end
165+
166+
"""
167+
grad!(model::BatchExaModel, x_global, g_global)
168+
169+
Evaluate all gradients.
170+
Output: g_global ∈ ℝ^{ns*nv}
171+
"""
172+
function grad!(
173+
model::BatchExaModel,
174+
x_global::AbstractVector,
175+
g_global::AbstractVector
176+
)
177+
grad!(model.model, x_global, g_global)
178+
return g_global
179+
end
180+
181+
"""
182+
jac_coord!(model::BatchExaModel, x_global, jac_global)
183+
184+
Evaluate all Jacobians (COO format).
185+
Output: jac_global ∈ ℝ^{ns*nnzj}
186+
"""
187+
function jac_coord!(
188+
model::BatchExaModel,
189+
x_global::AbstractVector,
190+
jac_global::AbstractVector
191+
)
192+
jac_coord!(model.model, x_global, jac_global)
193+
return jac_global
194+
end
195+
196+
"""
197+
jac_structure!(model::BatchExaModel, jrows, jcols)
198+
199+
Get the common sparsity pattern of the Jacobian.
200+
Output: jrows ∈ ℝ^{ns*nnzj} and jcols ∈ ℝ^{ns*nnzj}
201+
"""
202+
function jac_structure!(
203+
model::BatchExaModel,
204+
jrows::AbstractVector{<:Integer},
205+
jcols::AbstractVector{<:Integer}
206+
)
207+
jac_structure!(model.model, jrows, jcols)
208+
return jrows, jcols
209+
end
210+
211+
"""
212+
hess_coord!(model::BatchExaModel, x_global, y_global, hess_global; obj_weight=1.0)
213+
214+
Evaluate all Hessians of the Lagrangian (COO format).
215+
Output: hess_global ∈ ℝ^{ns*nnzh}
216+
"""
217+
function hess_coord!(
218+
model::BatchExaModel,
219+
x_global::AbstractVector,
220+
y_global::AbstractVector,
221+
hess_global::AbstractVector;
222+
obj_weight = one(eltype(x_global))
223+
)
224+
hess_coord!(model.model, x_global, y_global, hess_global; obj_weight = obj_weight)
225+
return hess_global
226+
end
227+
228+
"""
229+
hess_structure!(model::BatchExaModel, hrows, hcols)
230+
231+
Get the common sparsity pattern of the Hessian of the Lagrangian.
232+
Output: hrows ∈ ℝ^{ns*nnzh} and hcols ∈ ℝ^{ns*nnzh}
233+
"""
234+
function hess_structure!(
235+
model::BatchExaModel,
236+
hrows::AbstractVector{<:Integer},
237+
hcols::AbstractVector{<:Integer}
238+
)
239+
hess_structure!(model.model, hrows, hcols)
240+
return hrows, hcols
241+
end
242+
243+
# ============================================================================
244+
# NLPModels Interface
245+
# ============================================================================
246+
247+
"""
248+
get_nnzj(model::BatchExaModel)
249+
250+
Total number of Jacobian nonzeros.
251+
"""
252+
NLPModels.get_nnzj(model::BatchExaModel) = NLPModels.get_nnzj(model.model)
253+
254+
"""
255+
get_nnzh(model::BatchExaModel)
256+
257+
Total number of Hessian nonzeros.
258+
"""
259+
NLPModels.get_nnzh(model::BatchExaModel) = NLPModels.get_nnzh(model.model)
260+
261+
"""
262+
get_nvar(model::BatchExaModel)
263+
264+
Total number of variables.
265+
"""
266+
NLPModels.get_nvar(model::BatchExaModel) = NLPModels.get_nvar(model.model)
267+
268+
"""
269+
get_ncon(model::BatchExaModel)
270+
271+
Total number of constraints.
272+
"""
273+
NLPModels.get_ncon(model::BatchExaModel) = NLPModels.get_ncon(model.model)

0 commit comments

Comments
 (0)