Skip to content

Commit 393da2f

Browse files
authored
feat: add m.copy() method to create deep copy of model (#623)
* Added m.copy() method. * Added testing suite for m.copy(). * Fix solver_dir type annotation. * Bug fix: xarray copyies need to be . * Moved copy to io.py, added deep-copy to all xarray operations. * Improved copy method: Strengtheninc copy protocol compatibility, check for deep copy independence. * Added release notes. * Made Model.copy defaulting to deep copy more explicit. * Fine-tuned docs and added to read the docs api.rst.
1 parent 474f79b commit 393da2f

File tree

5 files changed

+304
-2
lines changed

5 files changed

+304
-2
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Creating a model
2424
piecewise.segments
2525
model.Model.linexpr
2626
model.Model.remove_constraints
27+
model.Model.copy
2728

2829

2930
Classes under the hook

doc/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
* Add ``Model.copy()`` (default deep copy) with ``deep`` and ``include_solution`` options; support Python ``copy.copy`` and ``copy.deepcopy`` protocols via ``__copy__`` and ``__deepcopy__``.
78
* Harmonize coordinate alignment for operations with subset/superset objects:
89
- Multiplication and division fill missing coords with 0 (variable doesn't participate)
910
- Addition and subtraction of constants fill missing coords with 0 (identity element) and pin result to LHS coords

linopy/io.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,3 +1239,124 @@ def get_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset:
12391239
setattr(m, k, ds.attrs.get(k))
12401240

12411241
return m
1242+
1243+
1244+
def copy(m: Model, include_solution: bool = False, deep: bool = True) -> Model:
1245+
"""
1246+
Return a copy of this model.
1247+
1248+
With ``deep=True`` (default), variables, constraints, objective,
1249+
parameters, blocks, and scalar attributes are copied to a fully
1250+
independent model. With ``deep=False``, returns a shallow copy.
1251+
1252+
:meth:`Model.copy` defaults to deep copy for workflow safety.
1253+
In contrast, ``copy.copy(model)`` is shallow via ``__copy__``, and
1254+
``copy.deepcopy(model)`` is deep via ``__deepcopy__``.
1255+
1256+
Solver runtime metadata (for example, ``solver_name`` and
1257+
``solver_model``) is intentionally not copied. Solver backend state
1258+
is recreated on ``solve()``.
1259+
1260+
Parameters
1261+
----------
1262+
m : Model
1263+
The model to copy.
1264+
include_solution : bool, optional
1265+
Whether to include solution and dual values in the copy.
1266+
If False (default), solve artifacts are excluded: solution/dual data,
1267+
objective value, and solve status are reset to initialized state.
1268+
If True, these values are copied when present. For unsolved models,
1269+
this has no additional effect.
1270+
deep : bool, optional
1271+
Whether to return a deep copy (default) or shallow copy. If False,
1272+
the returned model uses independent wrapper objects that share
1273+
underlying data buffers with the source model.
1274+
1275+
Returns
1276+
-------
1277+
Model
1278+
A deep or shallow copy of the model.
1279+
"""
1280+
from linopy.model import (
1281+
Constraint,
1282+
Constraints,
1283+
LinearExpression,
1284+
Model,
1285+
Objective,
1286+
Variable,
1287+
Variables,
1288+
)
1289+
1290+
SOLVE_STATE_ATTRS = {"status", "termination_condition"}
1291+
1292+
new_model = Model(
1293+
chunk=m._chunk,
1294+
force_dim_names=m._force_dim_names,
1295+
auto_mask=m._auto_mask,
1296+
solver_dir=str(m._solver_dir),
1297+
)
1298+
1299+
new_model._variables = Variables(
1300+
{
1301+
name: Variable(
1302+
var.data.copy(deep=deep)
1303+
if include_solution
1304+
else var.data[m.variables.dataset_attrs].copy(deep=deep),
1305+
new_model,
1306+
name,
1307+
)
1308+
for name, var in m.variables.items()
1309+
},
1310+
new_model,
1311+
)
1312+
1313+
new_model._constraints = Constraints(
1314+
{
1315+
name: Constraint(
1316+
con.data.copy(deep=deep)
1317+
if include_solution
1318+
else con.data[m.constraints.dataset_attrs].copy(deep=deep),
1319+
new_model,
1320+
name,
1321+
)
1322+
for name, con in m.constraints.items()
1323+
},
1324+
new_model,
1325+
)
1326+
1327+
obj_expr = LinearExpression(m.objective.expression.data.copy(deep=deep), new_model)
1328+
new_model._objective = Objective(obj_expr, new_model, m.objective.sense)
1329+
new_model._objective._value = m.objective.value if include_solution else None
1330+
1331+
new_model._parameters = m._parameters.copy(deep=deep)
1332+
new_model._blocks = m._blocks.copy(deep=deep) if m._blocks is not None else None
1333+
1334+
for attr in m.scalar_attrs:
1335+
if include_solution or attr not in SOLVE_STATE_ATTRS:
1336+
setattr(new_model, attr, getattr(m, attr))
1337+
1338+
return new_model
1339+
1340+
1341+
def shallowcopy(m: Model) -> Model:
1342+
"""
1343+
Support Python's ``copy.copy`` protocol for ``Model``.
1344+
1345+
Returns a shallow copy with independent wrapper objects that share
1346+
underlying array buffers with ``m``. Solve artifacts are excluded,
1347+
matching :meth:`Model.copy` defaults.
1348+
"""
1349+
return copy(m, include_solution=False, deep=False)
1350+
1351+
1352+
def deepcopy(m: Model, memo: dict[int, Any]) -> Model:
1353+
"""
1354+
Support Python's ``copy.deepcopy`` protocol for ``Model``.
1355+
1356+
Returns a deep, structurally independent copy and records it in ``memo``
1357+
as required by Python's copy protocol. Solve artifacts are excluded,
1358+
matching :meth:`Model.copy` defaults.
1359+
"""
1360+
new_model = copy(m, include_solution=False, deep=True)
1361+
memo[id(m)] = new_model
1362+
return new_model

linopy/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
ScalarLinearExpression,
5454
)
5555
from linopy.io import (
56+
copy,
57+
deepcopy,
58+
shallowcopy,
5659
to_block_files,
5760
to_cupdlpx,
5861
to_file,
@@ -1877,6 +1880,12 @@ def reset_solution(self) -> None:
18771880
self.variables.reset_solution()
18781881
self.constraints.reset_dual()
18791882

1883+
copy = copy
1884+
1885+
__copy__ = shallowcopy
1886+
1887+
__deepcopy__ = deepcopy
1888+
18801889
to_netcdf = to_netcdf
18811890

18821891
to_file = to_file

test/test_model.py

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@
55

66
from __future__ import annotations
77

8+
import copy as pycopy
89
from pathlib import Path
910
from tempfile import gettempdir
1011

1112
import numpy as np
1213
import pytest
1314
import xarray as xr
1415

15-
from linopy import EQUAL, Model
16-
from linopy.testing import assert_model_equal
16+
from linopy import EQUAL, Model, available_solvers
17+
from linopy.testing import (
18+
assert_conequal,
19+
assert_equal,
20+
assert_linequal,
21+
assert_model_equal,
22+
)
1723

1824
target_shape: tuple[int, int] = (10, 10)
1925

@@ -163,3 +169,167 @@ def test_assert_model_equal() -> None:
163169
m.add_objective(obj)
164170

165171
assert_model_equal(m, m)
172+
173+
174+
@pytest.fixture(scope="module")
175+
def copy_test_model() -> Model:
176+
"""Small representative model used across copy tests."""
177+
m: Model = Model()
178+
179+
lower: xr.DataArray = xr.DataArray(
180+
np.zeros((10, 10)), coords=[range(10), range(10)]
181+
)
182+
upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
183+
x = m.add_variables(lower, upper, name="x")
184+
y = m.add_variables(name="y")
185+
186+
m.add_constraints(1 * x + 10 * y, EQUAL, 0)
187+
m.add_objective((10 * x + 5 * y).sum())
188+
189+
return m
190+
191+
192+
@pytest.fixture(scope="module")
193+
def solved_copy_test_model(copy_test_model: Model) -> Model:
194+
"""Solved representative model used across solved-copy tests."""
195+
m = copy_test_model.copy(deep=True)
196+
m.solve()
197+
return m
198+
199+
200+
def test_model_copy_unsolved(copy_test_model: Model) -> None:
201+
"""Copy of unsolved model is structurally equal and independent."""
202+
m = copy_test_model.copy(deep=True)
203+
c = m.copy(include_solution=False)
204+
205+
assert_model_equal(m, c)
206+
207+
# independence: mutating copy does not affect source
208+
c.add_variables(name="z")
209+
assert "z" not in m.variables
210+
211+
212+
def test_model_copy_unsolved_with_solution_flag(copy_test_model: Model) -> None:
213+
"""Unsolved model with include_solution=True has no extra solve artifacts."""
214+
m = copy_test_model.copy(deep=True)
215+
216+
c_include_solution = m.copy(include_solution=True)
217+
c_exclude_solution = m.copy(include_solution=False)
218+
219+
assert_model_equal(c_include_solution, c_exclude_solution)
220+
assert c_include_solution.status == "initialized"
221+
assert c_include_solution.termination_condition == ""
222+
assert c_include_solution.objective.value is None
223+
224+
225+
def test_model_copy_shallow(copy_test_model: Model) -> None:
226+
"""Shallow copy has independent wrappers sharing underlying data buffers."""
227+
m = copy_test_model.copy(deep=True)
228+
c = m.copy(deep=False)
229+
230+
assert c is not m
231+
assert c.variables is not m.variables
232+
assert c.constraints is not m.constraints
233+
assert c.objective is not m.objective
234+
235+
# wrappers are distinct, but shallow copy shares payload buffers
236+
c.variables["x"].lower.values[0, 0] = 123.0
237+
assert m.variables["x"].lower.values[0, 0] == 123.0
238+
239+
240+
def test_model_deepcopy_protocol(copy_test_model: Model) -> None:
241+
"""copy.deepcopy(model) dispatches to Model.__deepcopy__ and stays independent."""
242+
m = copy_test_model.copy(deep=True)
243+
c = pycopy.deepcopy(m)
244+
245+
assert_model_equal(m, c)
246+
247+
# Test independence: mutations to copy do not affect source
248+
# 1. Variable mutation: add new variable
249+
c.add_variables(name="z")
250+
assert "z" not in m.variables
251+
252+
# 2. Variable data mutation (bounds): verify buffers are independent
253+
original_lower = m.variables["x"].lower.values[0, 0].item()
254+
new_lower = 999
255+
c.variables["x"].lower.values[0, 0] = new_lower
256+
assert c.variables["x"].lower.values[0, 0] == new_lower
257+
assert m.variables["x"].lower.values[0, 0] == original_lower
258+
259+
# 3. Constraint coefficient mutation: deep copy must not leak back
260+
original_con_coeff = m.constraints["con0"].coeffs.values.flat[0].item()
261+
new_con_coeff = original_con_coeff + 42
262+
c.constraints["con0"].coeffs.values.flat[0] = new_con_coeff
263+
assert c.constraints["con0"].coeffs.values.flat[0] == new_con_coeff
264+
assert m.constraints["con0"].coeffs.values.flat[0] == original_con_coeff
265+
266+
# 4. Objective expression coefficient mutation: deep copy must not leak back
267+
original_obj_coeff = m.objective.expression.coeffs.values.flat[0].item()
268+
new_obj_coeff = original_obj_coeff + 20
269+
c.objective.expression.coeffs.values.flat[0] = new_obj_coeff
270+
assert c.objective.expression.coeffs.values.flat[0] == new_obj_coeff
271+
assert m.objective.expression.coeffs.values.flat[0] == original_obj_coeff
272+
273+
# 5. Objective sense mutation
274+
original_sense = m.objective.sense
275+
c.objective.sense = "max"
276+
assert c.objective.sense == "max"
277+
assert m.objective.sense == original_sense
278+
279+
280+
@pytest.mark.skipif(not available_solvers, reason="No solver installed")
281+
class TestModelCopySolved:
282+
def test_model_deepcopy_protocol_excludes_solution(
283+
self, solved_copy_test_model: Model
284+
) -> None:
285+
"""copy.deepcopy on solved model drops solve state by default."""
286+
m = solved_copy_test_model
287+
288+
c = pycopy.deepcopy(m)
289+
290+
assert c.status == "initialized"
291+
assert c.termination_condition == ""
292+
assert c.objective.value is None
293+
294+
for v in m.variables:
295+
assert_equal(
296+
c.variables[v].data[c.variables.dataset_attrs],
297+
m.variables[v].data[m.variables.dataset_attrs],
298+
)
299+
for con in m.constraints:
300+
assert_conequal(c.constraints[con], m.constraints[con], strict=False)
301+
assert_linequal(c.objective.expression, m.objective.expression)
302+
assert c.objective.sense == m.objective.sense
303+
304+
def test_model_copy_solved_with_solution(
305+
self, solved_copy_test_model: Model
306+
) -> None:
307+
"""Copy with include_solution=True preserves solve state."""
308+
m = solved_copy_test_model
309+
310+
c = m.copy(include_solution=True)
311+
assert_model_equal(m, c)
312+
313+
def test_model_copy_solved_without_solution(
314+
self, solved_copy_test_model: Model
315+
) -> None:
316+
"""Copy with include_solution=False (default) drops solve state but preserves problem structure."""
317+
m = solved_copy_test_model
318+
319+
c = m.copy(include_solution=False)
320+
321+
# solve state is dropped
322+
assert c.status == "initialized"
323+
assert c.termination_condition == ""
324+
assert c.objective.value is None
325+
326+
# problem structure is preserved — compare only dataset_attrs to exclude solution/dual
327+
for v in m.variables:
328+
assert_equal(
329+
c.variables[v].data[c.variables.dataset_attrs],
330+
m.variables[v].data[m.variables.dataset_attrs],
331+
)
332+
for con in m.constraints:
333+
assert_conequal(c.constraints[con], m.constraints[con], strict=False)
334+
assert_linequal(c.objective.expression, m.objective.expression)
335+
assert c.objective.sense == m.objective.sense

0 commit comments

Comments
 (0)