Skip to content

Commit c7d36b9

Browse files
committed
Merge remote-tracking branch 'origin/master' into realtime-trace-jsonl
2 parents fc869bc + ef78d29 commit c7d36b9

4 files changed

Lines changed: 67 additions & 37 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Wrapped isObjIntegral() and test
99
- Added structured_optimization_trace recipe for structured optimization progress tracking
1010
- Added methods: getPrimalDualIntegral()
11+
- getSolVal() supports MatrixExpr now
1112
- Added realtime_trace_jsonl recipe for real-time optimization progress tracking with JSONL streaming output
1213
### Fixed
1314
- getBestSol() now returns None for infeasible problems instead of a Solution with NULL pointer

src/pyscipopt/scip.pxi

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,16 @@ cdef class Solution:
10981098
sol.scip = scip
10991099
return sol
11001100

1101-
def __getitem__(self, expr: Union[Expr, MatrixExpr]):
1101+
def __getitem__(
1102+
self,
1103+
expr: Union[Expr, GenExpr, MatrixExpr],
1104+
) -> Union[float, np.ndarray]:
1105+
if not isinstance(expr, (Expr, GenExpr, MatrixExpr)):
1106+
raise TypeError(
1107+
"Argument 'expr' has incorrect type, expected 'Expr', 'GenExpr', or "
1108+
f"'MatrixExpr', got {type(expr).__name__!r}"
1109+
)
1110+
11021111
self._checkStage("SCIPgetSolVal")
11031112
return expr._evaluate(self)
11041113

@@ -10968,75 +10977,63 @@ cdef class Model:
1096810977
def getSolVal(
1096910978
self,
1097010979
Solution sol,
10971-
expr: Union[Expr, GenExpr],
10980+
expr: Union[Expr, GenExpr, MatrixExpr],
1097210981
) -> Union[float, np.ndarray]:
1097310982
"""
10974-
Retrieve value of given variable or expression in the given solution or in
10975-
the LP/pseudo solution if sol == None
10983+
Retrieve value of given variable or expression in the given solution.
1097610984

1097710985
Parameters
1097810986
----------
1097910987
sol : Solution
10980-
expr : Expr
10981-
polynomial expression to query the value of
10988+
Solution to query the value from. If None, the current LP/pseudo solution is
10989+
used.
10990+
10991+
expr : Expr, GenExpr, MatrixExpr
10992+
Expression to query the value of.
1098210993

1098310994
Returns
1098410995
-------
10985-
float
10996+
float or np.ndarray
1098610997

1098710998
Notes
1098810999
-----
1098911000
A variable is also an expression.
1099011001

1099111002
"""
10992-
if not isinstance(expr, (Expr, GenExpr)):
10993-
raise TypeError(
10994-
"Argument 'expr' has incorrect type (expected 'Expr' or 'GenExpr', "
10995-
f"got {type(expr)})"
10996-
)
1099711003
# no need to create a NULL solution wrapper in case we have a variable
1099811004
return (sol or Solution.create(self._scip, NULL))[expr]
1099911005

11000-
def getVal(self, expr: Union[Expr, GenExpr, MatrixExpr] ):
11006+
def getVal(self, expr: Union[Expr, GenExpr, MatrixExpr]) -> Union[float, np.ndarray]:
1100111007
"""
1100211008
Retrieve the value of the given variable or expression in the best known solution.
1100311009
Can only be called after solving is completed.
1100411010

1100511011
Parameters
1100611012
----------
1100711013
expr : Expr, GenExpr or MatrixExpr
11014+
Expression to query the value of.
1100811015

1100911016
Returns
1101011017
-------
11011-
float
11018+
float or np.ndarray
1101211019

1101311020
Notes
1101411021
-----
1101511022
A variable is also an expression.
1101611023

1101711024
"""
11018-
cdef SCIP_SOL* current_best_sol
11019-
11020-
stage_check = SCIPgetStage(self._scip) not in [SCIP_STAGE_INIT, SCIP_STAGE_FREE]
11021-
if not stage_check:
11025+
if SCIPgetStage(self._scip) in {SCIP_STAGE_INIT, SCIP_STAGE_FREE}:
1102211026
raise Warning("Method cannot be called in stage ", self.getStage())
1102311027

1102411028
# Ensure _bestSol is up-to-date (cheap pointer comparison)
11025-
current_best_sol = SCIPgetBestSol(self._scip)
11029+
cdef SCIP_SOL* current_best_sol = SCIPgetBestSol(self._scip)
1102611030
if self._bestSol is None or self._bestSol.sol != current_best_sol:
1102711031
self._bestSol = Solution.create(self._scip, current_best_sol)
1102811032

1102911033
if self._bestSol.sol == NULL and SCIPgetStage(self._scip) != SCIP_STAGE_SOLVING:
1103011034
raise Warning("No solution available")
1103111035

11032-
if isinstance(expr, MatrixExpr):
11033-
result = np.empty(expr.shape, dtype=float)
11034-
for idx in np.ndindex(result.shape):
11035-
result[idx] = self.getSolVal(self._bestSol, expr[idx])
11036-
else:
11037-
result = self.getSolVal(self._bestSol, expr)
11038-
11039-
return result
11036+
return self._bestSol[expr]
1104011037

1104111038
def hasPrimalRay(self):
1104211039
"""

src/pyscipopt/scip.pyi

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import ClassVar
1+
from typing import ClassVar, Union, overload
22

3-
import numpy
3+
import numpy as np
44
from _typeshed import Incomplete
55
from typing_extensions import disjoint_base
66

@@ -496,7 +496,7 @@ class LP:
496496
def solve(self, dual: Incomplete = ...) -> Incomplete: ...
497497
def writeLP(self, filename: Incomplete) -> Incomplete: ...
498498

499-
class MatrixConstraint(numpy.ndarray):
499+
class MatrixConstraint(np.ndarray):
500500
def getConshdlrName(self) -> Incomplete: ...
501501
def isActive(self) -> Incomplete: ...
502502
def isChecked(self) -> Incomplete: ...
@@ -512,7 +512,7 @@ class MatrixConstraint(numpy.ndarray):
512512
def isSeparated(self) -> Incomplete: ...
513513
def isStickingAtNode(self) -> Incomplete: ...
514514

515-
class MatrixExpr(numpy.ndarray):
515+
class MatrixExpr(np.ndarray):
516516
def _evaluate(self, sol: Incomplete) -> Incomplete: ...
517517
def __array_ufunc__(
518518
self,
@@ -522,7 +522,7 @@ class MatrixExpr(numpy.ndarray):
522522
**kwargs: Incomplete,
523523
) -> Incomplete: ...
524524

525-
class MatrixExprCons(numpy.ndarray):
525+
class MatrixExprCons(np.ndarray):
526526
def __array_ufunc__(
527527
self,
528528
ufunc: Incomplete,
@@ -1215,7 +1215,10 @@ class Model:
12151215
self, sol: Incomplete, original: Incomplete = ...
12161216
) -> Incomplete: ...
12171217
def getSolTime(self, sol: Incomplete) -> Incomplete: ...
1218-
def getSolVal(self, sol: Incomplete, expr: Incomplete) -> Incomplete: ...
1218+
@overload
1219+
def getSolVal(self, sol: Solution, expr: Union[Expr, GenExpr]) -> float: ...
1220+
@overload
1221+
def getSolVal(self, sol: Solution, expr: MatrixExpr) -> np.ndarray: ...
12191222
def getSols(self) -> Incomplete: ...
12201223
def getSolvingTime(self) -> Incomplete: ...
12211224
def getStage(self) -> Incomplete: ...
@@ -1227,7 +1230,10 @@ class Model:
12271230
def getTransformedCons(self, cons: Incomplete) -> Incomplete: ...
12281231
def getTransformedVar(self, var: Incomplete) -> Incomplete: ...
12291232
def getTreesizeEstimation(self) -> Incomplete: ...
1230-
def getVal(self, expr: Incomplete) -> Incomplete: ...
1233+
@overload
1234+
def getVal(self, expr: Union[Expr, GenExpr]) -> float: ...
1235+
@overload
1236+
def getVal(self, expr: MatrixExpr) -> np.ndarray: ...
12311237
def getValsLinear(self, cons: Incomplete) -> Incomplete: ...
12321238
def getVarDict(self, transformed: Incomplete = ...) -> Incomplete: ...
12331239
def getVarLbDive(self, var: Incomplete) -> Incomplete: ...

tests/test_model.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
import pytest
2-
import os
31
import itertools
2+
import os
3+
4+
import numpy as np
5+
import pytest
46

5-
from pyscipopt import Model, SCIP_STAGE, SCIP_PARAMSETTING, SCIP_BRANCHDIR, quicksum
7+
from pyscipopt import SCIP_BRANCHDIR, SCIP_PARAMSETTING, SCIP_STAGE, Model, quicksum
68
from helpers.utils import random_mip_1
79

10+
811
def test_model():
912
# create solver instance
1013
s = Model()
@@ -616,3 +619,26 @@ def create_model_and_get_objects():
616619

617620
assert repr(x) == ""
618621
assert repr(c) == ""
622+
623+
624+
def test_getSolVal():
625+
# fix #1136
626+
627+
m = Model()
628+
x = m.addVar(vtype="B")
629+
y = m.addMatrixVar(2, vtype="B")
630+
631+
m.setObjective(x + y.sum())
632+
m.optimize()
633+
sol = m.getBestSol()
634+
635+
assert m.getSolVal(sol, x) == m.getVal(x)
636+
assert m.getVal(x) == 0
637+
638+
assert np.array_equal(m.getSolVal(sol, y), m.getVal(y))
639+
assert np.array_equal(m.getVal(y), np.array([0, 0]))
640+
641+
with pytest.raises(TypeError):
642+
m.getVal("not_a_var")
643+
with pytest.raises(TypeError):
644+
m.getSolVal(sol, "not_a_var")

0 commit comments

Comments
 (0)