|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
| 8 | +import copy as pycopy |
8 | 9 | from pathlib import Path |
9 | 10 | from tempfile import gettempdir |
10 | 11 |
|
11 | 12 | import numpy as np |
12 | 13 | import pytest |
13 | 14 | import xarray as xr |
14 | 15 |
|
15 | | -from linopy import EQUAL, Model |
16 | | -from linopy.testing import assert_model_equal |
| 16 | +from linopy import EQUAL, Model, available_solvers |
| 17 | +from linopy.testing import ( |
| 18 | + assert_conequal, |
| 19 | + assert_equal, |
| 20 | + assert_linequal, |
| 21 | + assert_model_equal, |
| 22 | +) |
17 | 23 |
|
18 | 24 | target_shape: tuple[int, int] = (10, 10) |
19 | 25 |
|
@@ -163,3 +169,167 @@ def test_assert_model_equal() -> None: |
163 | 169 | m.add_objective(obj) |
164 | 170 |
|
165 | 171 | assert_model_equal(m, m) |
| 172 | + |
| 173 | + |
| 174 | +@pytest.fixture(scope="module") |
| 175 | +def copy_test_model() -> Model: |
| 176 | + """Small representative model used across copy tests.""" |
| 177 | + m: Model = Model() |
| 178 | + |
| 179 | + lower: xr.DataArray = xr.DataArray( |
| 180 | + np.zeros((10, 10)), coords=[range(10), range(10)] |
| 181 | + ) |
| 182 | + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) |
| 183 | + x = m.add_variables(lower, upper, name="x") |
| 184 | + y = m.add_variables(name="y") |
| 185 | + |
| 186 | + m.add_constraints(1 * x + 10 * y, EQUAL, 0) |
| 187 | + m.add_objective((10 * x + 5 * y).sum()) |
| 188 | + |
| 189 | + return m |
| 190 | + |
| 191 | + |
| 192 | +@pytest.fixture(scope="module") |
| 193 | +def solved_copy_test_model(copy_test_model: Model) -> Model: |
| 194 | + """Solved representative model used across solved-copy tests.""" |
| 195 | + m = copy_test_model.copy(deep=True) |
| 196 | + m.solve() |
| 197 | + return m |
| 198 | + |
| 199 | + |
| 200 | +def test_model_copy_unsolved(copy_test_model: Model) -> None: |
| 201 | + """Copy of unsolved model is structurally equal and independent.""" |
| 202 | + m = copy_test_model.copy(deep=True) |
| 203 | + c = m.copy(include_solution=False) |
| 204 | + |
| 205 | + assert_model_equal(m, c) |
| 206 | + |
| 207 | + # independence: mutating copy does not affect source |
| 208 | + c.add_variables(name="z") |
| 209 | + assert "z" not in m.variables |
| 210 | + |
| 211 | + |
| 212 | +def test_model_copy_unsolved_with_solution_flag(copy_test_model: Model) -> None: |
| 213 | + """Unsolved model with include_solution=True has no extra solve artifacts.""" |
| 214 | + m = copy_test_model.copy(deep=True) |
| 215 | + |
| 216 | + c_include_solution = m.copy(include_solution=True) |
| 217 | + c_exclude_solution = m.copy(include_solution=False) |
| 218 | + |
| 219 | + assert_model_equal(c_include_solution, c_exclude_solution) |
| 220 | + assert c_include_solution.status == "initialized" |
| 221 | + assert c_include_solution.termination_condition == "" |
| 222 | + assert c_include_solution.objective.value is None |
| 223 | + |
| 224 | + |
| 225 | +def test_model_copy_shallow(copy_test_model: Model) -> None: |
| 226 | + """Shallow copy has independent wrappers sharing underlying data buffers.""" |
| 227 | + m = copy_test_model.copy(deep=True) |
| 228 | + c = m.copy(deep=False) |
| 229 | + |
| 230 | + assert c is not m |
| 231 | + assert c.variables is not m.variables |
| 232 | + assert c.constraints is not m.constraints |
| 233 | + assert c.objective is not m.objective |
| 234 | + |
| 235 | + # wrappers are distinct, but shallow copy shares payload buffers |
| 236 | + c.variables["x"].lower.values[0, 0] = 123.0 |
| 237 | + assert m.variables["x"].lower.values[0, 0] == 123.0 |
| 238 | + |
| 239 | + |
| 240 | +def test_model_deepcopy_protocol(copy_test_model: Model) -> None: |
| 241 | + """copy.deepcopy(model) dispatches to Model.__deepcopy__ and stays independent.""" |
| 242 | + m = copy_test_model.copy(deep=True) |
| 243 | + c = pycopy.deepcopy(m) |
| 244 | + |
| 245 | + assert_model_equal(m, c) |
| 246 | + |
| 247 | + # Test independence: mutations to copy do not affect source |
| 248 | + # 1. Variable mutation: add new variable |
| 249 | + c.add_variables(name="z") |
| 250 | + assert "z" not in m.variables |
| 251 | + |
| 252 | + # 2. Variable data mutation (bounds): verify buffers are independent |
| 253 | + original_lower = m.variables["x"].lower.values[0, 0].item() |
| 254 | + new_lower = 999 |
| 255 | + c.variables["x"].lower.values[0, 0] = new_lower |
| 256 | + assert c.variables["x"].lower.values[0, 0] == new_lower |
| 257 | + assert m.variables["x"].lower.values[0, 0] == original_lower |
| 258 | + |
| 259 | + # 3. Constraint coefficient mutation: deep copy must not leak back |
| 260 | + original_con_coeff = m.constraints["con0"].coeffs.values.flat[0].item() |
| 261 | + new_con_coeff = original_con_coeff + 42 |
| 262 | + c.constraints["con0"].coeffs.values.flat[0] = new_con_coeff |
| 263 | + assert c.constraints["con0"].coeffs.values.flat[0] == new_con_coeff |
| 264 | + assert m.constraints["con0"].coeffs.values.flat[0] == original_con_coeff |
| 265 | + |
| 266 | + # 4. Objective expression coefficient mutation: deep copy must not leak back |
| 267 | + original_obj_coeff = m.objective.expression.coeffs.values.flat[0].item() |
| 268 | + new_obj_coeff = original_obj_coeff + 20 |
| 269 | + c.objective.expression.coeffs.values.flat[0] = new_obj_coeff |
| 270 | + assert c.objective.expression.coeffs.values.flat[0] == new_obj_coeff |
| 271 | + assert m.objective.expression.coeffs.values.flat[0] == original_obj_coeff |
| 272 | + |
| 273 | + # 5. Objective sense mutation |
| 274 | + original_sense = m.objective.sense |
| 275 | + c.objective.sense = "max" |
| 276 | + assert c.objective.sense == "max" |
| 277 | + assert m.objective.sense == original_sense |
| 278 | + |
| 279 | + |
| 280 | +@pytest.mark.skipif(not available_solvers, reason="No solver installed") |
| 281 | +class TestModelCopySolved: |
| 282 | + def test_model_deepcopy_protocol_excludes_solution( |
| 283 | + self, solved_copy_test_model: Model |
| 284 | + ) -> None: |
| 285 | + """copy.deepcopy on solved model drops solve state by default.""" |
| 286 | + m = solved_copy_test_model |
| 287 | + |
| 288 | + c = pycopy.deepcopy(m) |
| 289 | + |
| 290 | + assert c.status == "initialized" |
| 291 | + assert c.termination_condition == "" |
| 292 | + assert c.objective.value is None |
| 293 | + |
| 294 | + for v in m.variables: |
| 295 | + assert_equal( |
| 296 | + c.variables[v].data[c.variables.dataset_attrs], |
| 297 | + m.variables[v].data[m.variables.dataset_attrs], |
| 298 | + ) |
| 299 | + for con in m.constraints: |
| 300 | + assert_conequal(c.constraints[con], m.constraints[con], strict=False) |
| 301 | + assert_linequal(c.objective.expression, m.objective.expression) |
| 302 | + assert c.objective.sense == m.objective.sense |
| 303 | + |
| 304 | + def test_model_copy_solved_with_solution( |
| 305 | + self, solved_copy_test_model: Model |
| 306 | + ) -> None: |
| 307 | + """Copy with include_solution=True preserves solve state.""" |
| 308 | + m = solved_copy_test_model |
| 309 | + |
| 310 | + c = m.copy(include_solution=True) |
| 311 | + assert_model_equal(m, c) |
| 312 | + |
| 313 | + def test_model_copy_solved_without_solution( |
| 314 | + self, solved_copy_test_model: Model |
| 315 | + ) -> None: |
| 316 | + """Copy with include_solution=False (default) drops solve state but preserves problem structure.""" |
| 317 | + m = solved_copy_test_model |
| 318 | + |
| 319 | + c = m.copy(include_solution=False) |
| 320 | + |
| 321 | + # solve state is dropped |
| 322 | + assert c.status == "initialized" |
| 323 | + assert c.termination_condition == "" |
| 324 | + assert c.objective.value is None |
| 325 | + |
| 326 | + # problem structure is preserved — compare only dataset_attrs to exclude solution/dual |
| 327 | + for v in m.variables: |
| 328 | + assert_equal( |
| 329 | + c.variables[v].data[c.variables.dataset_attrs], |
| 330 | + m.variables[v].data[m.variables.dataset_attrs], |
| 331 | + ) |
| 332 | + for con in m.constraints: |
| 333 | + assert_conequal(c.constraints[con], m.constraints[con], strict=False) |
| 334 | + assert_linequal(c.objective.expression, m.objective.expression) |
| 335 | + assert c.objective.sense == m.objective.sense |
0 commit comments