Skip to content

Commit ea28e5e

Browse files
committed
Light changes
1 parent f4de750 commit ea28e5e

2 files changed

Lines changed: 30 additions & 14 deletions

File tree

src/pyscipopt/recipes/getLocalConss.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,19 @@
22

33
def getLocalConss(model: Model, node = None) -> list[Constraint]:
44
"""
5-
Returns the local constraints of a node.
5+
Returns local constraints.
6+
7+
Parameters
8+
----------
9+
model : Model
10+
The model from which to retrieve the local constraints.
11+
node : Node, optional
12+
The node from which to retrieve the local constraints. If not provided, the current node is used.
13+
14+
Returns
15+
-------
16+
list[Constraint]
17+
A list of local constraints. First entry are global constraints, second entry are all the added constraints.
618
"""
719

820
if not node:
@@ -11,14 +23,15 @@ def getLocalConss(model: Model, node = None) -> list[Constraint]:
1123
else:
1224
cur_node = node
1325

14-
local_conss = []
26+
added_conss = []
1527
while cur_node is not None:
16-
local_conss = cur_node.getAddedConss() + local_conss
28+
added_conss = cur_node.getAddedConss() + added_conss
1729
cur_node = cur_node.getParent()
18-
return local_conss
30+
31+
return [model.getConss(), added_conss]
1932

20-
def getNLocalConss(model: Model) -> int:
33+
def getNAddedConss(model: Model) -> int:
2134
"""
2235
Returns the number of local constraints of a node.
2336
"""
24-
return len(getLocalConss(model))
37+
return len(getLocalConss(model)[1])

tests/test_recipe_getLocalConss.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,28 @@
44

55
def localconss(model, event):
66
local_conss = getLocalConss(model)
7-
assert len(local_conss) == getNLocalConss(model)
7+
assert len(local_conss[1]) == getNAddedConss(model)
88

99
vars = model.getVars()
1010
if model.getCurrentNode().getNumber() == 1:
1111
pass
12+
1213
elif model.getCurrentNode().getNumber() == 2:
1314
model.data["local_cons1"] = model.addCons(vars[0] + vars[1] <= 1, name="c1", local=True)
14-
assert getNLocalConss(model) == 1
15-
assert getLocalConss(model)[0] == model.data["local_cons1"]
15+
assert getNAddedConss(model) == 1
16+
assert getLocalConss(model)[1][0] == model.data["local_cons1"]
17+
1618
elif model.getCurrentNode().getNumber() == 4:
1719
local_conss = getLocalConss(model)
1820
model.data["local_cons2"] = model.addCons(vars[1] + vars[2] <= 1, name="c2", local=True)
1921
model.data["local_cons3"] = model.addCons(vars[2] + vars[3] <= 1, name="c3", local=True)
20-
assert getNLocalConss(model) == 3
21-
assert getLocalConss(model)[0] == model.data["local_cons1"]
22-
assert getLocalConss(model)[1] == model.data["local_cons2"]
23-
assert getLocalConss(model)[2] == model.data["local_cons3"]
22+
assert getNAddedConss(model) == 3
23+
assert getLocalConss(model)[1][0] == model.data["local_cons1"]
24+
assert getLocalConss(model)[1][1] == model.data["local_cons2"]
25+
assert getLocalConss(model)[1][2] == model.data["local_cons3"]
26+
2427
elif model.getCurrentNode().getParent().getNumber() not in [2,4]:
25-
assert getLocalConss(model) == []
28+
assert getLocalConss(model) == [model.getConss(), []]
2629

2730
def test_getLocalConss():
2831
model = random_mip_1(node_lim=4)

0 commit comments

Comments
 (0)