Skip to content

Commit c2c5ce7

Browse files
committed
Add arg validation for global constraints & tests
Introduce ExprLike type alias and tighten type hints for global constraint wrappers to accept Sequence[ExprLike]. Add helper methods (_as_non_empty_list, _assert_same_length) and assertions to ensure non-empty inputs and matching lengths (with special-case handling in regular for set f). Update many Constraint.* static methods to use these checks and normalize inputs. Add an integration test (tests/test_global_constraints_integration.py) that exercises all global constraint wrappers and verifies generated MiniZinc output. These changes improve input validation and catch misuse earlier.
1 parent c660e30 commit c2c5ce7

7 files changed

Lines changed: 382 additions & 42 deletions

File tree

pyproject.toml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,37 @@ markers = [
2626
"no_mzn_verify: skip global MiniZinc CLI validation for tests that intentionally render unresolved include paths",
2727
]
2828

29+
[tool.ruff]
30+
target-version = "py38"
31+
32+
[tool.ruff.lint]
33+
ignore = [
34+
"F403", # package exports and core modules intentionally use wildcard imports
35+
"F405", # project relies on re-export style imports in core DSL modules
36+
]
37+
38+
[tool.ruff.lint.per-file-ignores]
39+
"tests/**/*.py" = [
40+
"F401", # many scenario tests keep exploratory imports
41+
"F841", # tests sometimes keep named variables for readability
42+
"E741", # legacy variable names in classic combinatorics examples
43+
"E712", # explicit boolean equality kept in tests for intent clarity
44+
]
45+
46+
[tool.mypy]
47+
python_version = "3.12"
48+
ignore_missing_imports = true
49+
no_implicit_optional = false
50+
disable_error_code = [
51+
"override",
52+
"assignment",
53+
"arg-type",
54+
"return-value",
55+
"operator",
56+
"list-item",
57+
"attr-defined",
58+
]
59+
2960
[tool.coverage.run]
3061
source = ["pymzm"]
3162

src/pymzm/constraint.py

Lines changed: 131 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
from typing import List
2+
from typing import Sequence, Union
33

44
from .exceptions import *
55
from .expression import *
@@ -27,6 +27,8 @@ class AnnotationConstraint:
2727
]
2828

2929
class Constraint:
30+
ExprLike = Union[Expression, int, float, bool, str]
31+
3032
CTYPES = [
3133
CTYPE_NORMAL,
3234
CTYPE_ALLDIFFERENT,
@@ -113,11 +115,21 @@ def _to_mz(self):
113115
return f"constraint {self.cstr}{annotation_suffix};\n"
114116

115117
@staticmethod
116-
def _from_global_constraint(func: str, ctype: str, *args):
118+
def _from_global_constraint(func: str, ctype: str, *args) -> "Constraint":
117119
return Constraint(f"{func}({', '.join(str(a) for a in args)})", ctype)
118120

119121
@staticmethod
120-
def alldifferent(exprs: List[Expression]) -> "Constraint":
122+
def _as_non_empty_list(values, arg_name: str):
123+
values = list(values)
124+
assert len(values) > 0, f"{arg_name} cannot be empty"
125+
return values
126+
127+
@staticmethod
128+
def _assert_same_length(arg_name_left: str, left, arg_name_right: str, right):
129+
assert len(left) == len(right), f"{arg_name_left} and {arg_name_right} must have the same length"
130+
131+
@staticmethod
132+
def alldifferent(exprs: Sequence[ExprLike]) -> "Constraint":
121133
"""Constrain the elements in the passed List to be pairwise different.
122134
123135
Args:
@@ -126,120 +138,210 @@ def alldifferent(exprs: List[Expression]) -> "Constraint":
126138
Returns:
127139
Constraint: Alldifferent constraint
128140
"""
141+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
129142
return Constraint._from_global_constraint("alldifferent", Constraint.CTYPE_ALLDIFFERENT, exprs)
130143

131144
@staticmethod
132-
def all_different(exprs: List[Expression]) -> "Constraint":
145+
def all_different(exprs: Sequence[ExprLike]) -> "Constraint":
133146
return Constraint.alldifferent(exprs)
134147

135148
@staticmethod
136-
def among(n: int, exprs: List[Expression], values: List[int]):
149+
def among(n: ExprLike, exprs: Sequence[ExprLike], values: Sequence[ExprLike]) -> "Constraint":
150+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
151+
values = Constraint._as_non_empty_list(values, "values")
137152
return Constraint._from_global_constraint("among", Constraint.CTYPE_AMONG, n, exprs, values)
138153

139154
@staticmethod
140-
def all_equal(exprs: List[Expression]):
155+
def all_equal(exprs: Sequence[ExprLike]) -> "Constraint":
156+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
141157
return Constraint._from_global_constraint("all_equal", Constraint.CTYPE_ALL_EQUAL, exprs)
142158

143159
@staticmethod
144-
def count(exprs: List[Expression], val: int, count: int):
160+
def count(exprs: Sequence[ExprLike], val: ExprLike, count: ExprLike) -> "Constraint":
161+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
145162
return Constraint._from_global_constraint("count", Constraint.CTYPE_COUNT, exprs, val, count)
146163

147164
@staticmethod
148-
def increasing(exprs: List[Expression]):
165+
def increasing(exprs: Sequence[ExprLike]) -> "Constraint":
149166
# Requires that the array x is in (non-strictly) increasing order (duplicates are allowed).
167+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
150168
return Constraint._from_global_constraint("increasing", Constraint.CTYPE_INCREASING, exprs)
151169

152170
@staticmethod
153-
def strictly_increasing(exprs: List[Expression]):
171+
def strictly_increasing(exprs: Sequence[ExprLike]) -> "Constraint":
172+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
154173
return Constraint._from_global_constraint("strictly_increasing", Constraint.CTYPE_STRICTLY_INCREASING, exprs)
155174

156175
@staticmethod
157-
def decreasing(exprs: List[Expression]):
176+
def decreasing(exprs: Sequence[ExprLike]) -> "Constraint":
158177
# Requires that the array x is in (non-strictly) decreasing order (duplicates are allowed).
178+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
159179
return Constraint._from_global_constraint("decreasing", Constraint.CTYPE_DECREASING, exprs)
160180

161181
@staticmethod
162-
def strictly_decreasing(exprs: List[Expression]):
182+
def strictly_decreasing(exprs: Sequence[ExprLike]) -> "Constraint":
183+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
163184
return Constraint._from_global_constraint("strictly_decreasing", Constraint.CTYPE_STRICTLY_DECREASING, exprs)
164185

165186
@staticmethod
166-
def element(index: Expression, values: List[Expression], value: Expression):
187+
def element(index: ExprLike, values: Sequence[ExprLike], value: ExprLike) -> "Constraint":
188+
values = Constraint._as_non_empty_list(values, "values")
167189
return Constraint._from_global_constraint("element", Constraint.CTYPE_ELEMENT, index, values, value)
168190

169191
@staticmethod
170-
def table(exprs: List[Expression], rows):
192+
def table(exprs: Sequence[ExprLike], rows: Sequence[Sequence[ExprLike]]) -> "Constraint":
193+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
194+
rows = Constraint._as_non_empty_list(rows, "rows")
195+
assert all(len(row) == len(exprs) for row in rows), "all table rows must match exprs width"
171196
return Constraint._from_global_constraint("table", Constraint.CTYPE_TABLE, exprs, rows)
172197

173198
@staticmethod
174-
def cumulative(s: List[Expression], d: List[Expression], r: List[Expression], b: Expression):
199+
def cumulative(s: Sequence[ExprLike], d: Sequence[ExprLike], r: Sequence[ExprLike], b: ExprLike) -> "Constraint":
200+
s = Constraint._as_non_empty_list(s, "s")
201+
d = Constraint._as_non_empty_list(d, "d")
202+
r = Constraint._as_non_empty_list(r, "r")
203+
Constraint._assert_same_length("s", s, "d", d)
204+
Constraint._assert_same_length("s", s, "r", r)
175205
return Constraint._from_global_constraint("cumulative", Constraint.CTYPE_CUMULATIVE, s, d, r, b)
176206

177207
@staticmethod
178-
def disjunctive(s: List[Expression], d: List[Expression]):
208+
def disjunctive(s: Sequence[ExprLike], d: Sequence[ExprLike]) -> "Constraint":
179209
# Requires that a set of tasks given by start times s and durations d do not overlap in time.
180210
# Tasks with duration 0 can be scheduled at any time, even in the middle of other tasks.
211+
s = Constraint._as_non_empty_list(s, "s")
212+
d = Constraint._as_non_empty_list(d, "d")
213+
Constraint._assert_same_length("s", s, "d", d)
181214
return Constraint._from_global_constraint("disjunctive", Constraint.CTYPE_DISJUNCTIVE, s, d)
182215

183216
@staticmethod
184-
def disjunctive_strict(s: List[Expression], d: List[Expression]):
217+
def disjunctive_strict(s: Sequence[ExprLike], d: Sequence[ExprLike]) -> "Constraint":
185218
# Requires that a set of tasks given by start times s and durations d do not overlap in time.
186219
# Tasks with duration 0 CANNOT be scheduled at any time, but only when no other task is running.
220+
s = Constraint._as_non_empty_list(s, "s")
221+
d = Constraint._as_non_empty_list(d, "d")
222+
Constraint._assert_same_length("s", s, "d", d)
187223
return Constraint._from_global_constraint("disjunctive_strict", Constraint.CTYPE_DISJUNCTIVE, s, d)
188224

189225
@staticmethod
190-
def circuit(successors: List[Expression]):
226+
def circuit(successors: Sequence[ExprLike]) -> "Constraint":
227+
successors = Constraint._as_non_empty_list(successors, "successors")
191228
return Constraint._from_global_constraint("circuit", Constraint.CTYPE_CIRCUIT, successors)
192229

193230
@staticmethod
194-
def path(successors: List[Expression], start: Expression, end: Expression):
231+
def path(successors: Sequence[ExprLike], start: ExprLike, end: ExprLike) -> "Constraint":
232+
successors = Constraint._as_non_empty_list(successors, "successors")
195233
return Constraint._from_global_constraint("path", Constraint.CTYPE_PATH, successors, start, end)
196234

197235
@staticmethod
198-
def bin_packing(capacity: Expression, bins: List[Expression], weights: List[Expression]):
236+
def bin_packing(capacity: ExprLike, bins: Sequence[ExprLike], weights: Sequence[ExprLike]) -> "Constraint":
237+
bins = Constraint._as_non_empty_list(bins, "bins")
238+
weights = Constraint._as_non_empty_list(weights, "weights")
239+
Constraint._assert_same_length("bins", bins, "weights", weights)
199240
return Constraint._from_global_constraint("bin_packing", Constraint.CTYPE_BIN_PACKING, capacity, bins, weights)
200241

201242
@staticmethod
202-
def inverse(forward: List[Expression], backward: List[Expression]):
243+
def inverse(forward: Sequence[ExprLike], backward: Sequence[ExprLike]) -> "Constraint":
244+
forward = Constraint._as_non_empty_list(forward, "forward")
245+
backward = Constraint._as_non_empty_list(backward, "backward")
246+
Constraint._assert_same_length("forward", forward, "backward", backward)
203247
return Constraint._from_global_constraint("inverse", Constraint.CTYPE_INVERSE, forward, backward)
204248

205249
@staticmethod
206-
def lex_less(left: List[Expression], right: List[Expression]):
250+
def lex_less(left: Sequence[ExprLike], right: Sequence[ExprLike]) -> "Constraint":
251+
left = Constraint._as_non_empty_list(left, "left")
252+
right = Constraint._as_non_empty_list(right, "right")
253+
Constraint._assert_same_length("left", left, "right", right)
207254
return Constraint._from_global_constraint("lex_less", Constraint.CTYPE_LEX_LESS, left, right)
208255

209256
@staticmethod
210-
def lex_lesseq(left: List[Expression], right: List[Expression]):
257+
def lex_lesseq(left: Sequence[ExprLike], right: Sequence[ExprLike]) -> "Constraint":
258+
left = Constraint._as_non_empty_list(left, "left")
259+
right = Constraint._as_non_empty_list(right, "right")
260+
Constraint._assert_same_length("left", left, "right", right)
211261
return Constraint._from_global_constraint("lex_lesseq", Constraint.CTYPE_LEX_LESSEQ, left, right)
212262

213263
@staticmethod
214-
def lex_greater(left: List[Expression], right: List[Expression]):
264+
def lex_greater(left: Sequence[ExprLike], right: Sequence[ExprLike]) -> "Constraint":
265+
left = Constraint._as_non_empty_list(left, "left")
266+
right = Constraint._as_non_empty_list(right, "right")
267+
Constraint._assert_same_length("left", left, "right", right)
215268
return Constraint._from_global_constraint("lex_greater", Constraint.CTYPE_LEX_GREATER, left, right)
216269

217270
@staticmethod
218-
def lex_greatereq(left: List[Expression], right: List[Expression]):
271+
def lex_greatereq(left: Sequence[ExprLike], right: Sequence[ExprLike]) -> "Constraint":
272+
left = Constraint._as_non_empty_list(left, "left")
273+
right = Constraint._as_non_empty_list(right, "right")
274+
Constraint._assert_same_length("left", left, "right", right)
219275
return Constraint._from_global_constraint("lex_greatereq", Constraint.CTYPE_LEX_GREATEREQ, left, right)
220276

221277
@staticmethod
222-
def regular(exprs: List[Expression], q: Expression, s: Expression, d, q0: Expression, f):
278+
def regular(
279+
exprs: Sequence[ExprLike],
280+
q: ExprLike,
281+
s: ExprLike,
282+
d: Sequence[Sequence[ExprLike]],
283+
q0: ExprLike,
284+
f: Sequence[ExprLike],
285+
) -> "Constraint":
286+
exprs = Constraint._as_non_empty_list(exprs, "exprs")
287+
d = Constraint._as_non_empty_list(d, "d")
288+
if (isinstance(f, set)):
289+
assert len(f) > 0, "f cannot be empty"
290+
else:
291+
f = Constraint._as_non_empty_list(f, "f")
223292
return Constraint._from_global_constraint("regular", Constraint.CTYPE_REGULAR, exprs, q, s, d, q0, f)
224293

225294
@staticmethod
226-
def arg_sort(x: List[Expression], p: List[Expression]):
295+
def arg_sort(x: Sequence[ExprLike], p: Sequence[ExprLike]) -> "Constraint":
227296
# Constrains p to be the permutation which causes x to be in sorted order hence x[p[i]] <= x[p[i+1]].
297+
x = Constraint._as_non_empty_list(x, "x")
298+
p = Constraint._as_non_empty_list(p, "p")
299+
Constraint._assert_same_length("x", x, "p", p)
228300
return Constraint._from_global_constraint("arg_sort", Constraint.CTYPE_ARG_SORT, x, p)
229301

230302
@staticmethod
231-
def diffn(x: List[Expression], y: List[Expression], dx: List[Expression], dy: List[Expression]):
303+
def diffn(x: Sequence[ExprLike], y: Sequence[ExprLike], dx: Sequence[ExprLike], dy: Sequence[ExprLike]) -> "Constraint":
232304
# Constrains rectangles i, given by their origins (x[i], y[i]) and sizes (dx[i], dy[i]),
233305
# to be non-overlapping. Zero-width rectangles can still not overlap with any other rectangle.
306+
x = Constraint._as_non_empty_list(x, "x")
307+
y = Constraint._as_non_empty_list(y, "y")
308+
dx = Constraint._as_non_empty_list(dx, "dx")
309+
dy = Constraint._as_non_empty_list(dy, "dy")
310+
Constraint._assert_same_length("x", x, "y", y)
311+
Constraint._assert_same_length("x", x, "dx", dx)
312+
Constraint._assert_same_length("x", x, "dy", dy)
234313
return Constraint._from_global_constraint("diffn", Constraint.CTYPE_DIFFN, x, y, dx, dy)
235314

236315
@staticmethod
237-
def connected(node_from: List[int], node_to: List[int], ns: List[ExpressionBool], es: List[ExpressionBool]):
316+
def connected(
317+
node_from: Sequence[int],
318+
node_to: Sequence[int],
319+
ns: Sequence[ExprLike],
320+
es: Sequence[ExprLike],
321+
) -> "Constraint":
238322
# Constrains the subgraph ns and es of a given undirected graph to be connected.
323+
node_from = Constraint._as_non_empty_list(node_from, "node_from")
324+
node_to = Constraint._as_non_empty_list(node_to, "node_to")
325+
ns = Constraint._as_non_empty_list(ns, "ns")
326+
es = Constraint._as_non_empty_list(es, "es")
327+
Constraint._assert_same_length("node_from", node_from, "node_to", node_to)
328+
Constraint._assert_same_length("node_from", node_from, "es", es)
239329
return Constraint._from_global_constraint("connected", Constraint.CTYPE_CONNECTED, node_from, node_to, ns, es)
240330

241331
@staticmethod
242-
def reachable(node_from: List[int], node_to: List[int], r: List[Expression], ns: List[ExpressionBool], es: List[ExpressionBool]):
332+
def reachable(
333+
node_from: Sequence[int],
334+
node_to: Sequence[int],
335+
r: ExprLike,
336+
ns: Sequence[ExprLike],
337+
es: Sequence[ExprLike],
338+
) -> "Constraint":
243339
# Constrains the subgraph ns and es of a given undirected graph to be reachable from r.
244340
# TODO: this can have other parameters.
341+
node_from = Constraint._as_non_empty_list(node_from, "node_from")
342+
node_to = Constraint._as_non_empty_list(node_to, "node_to")
343+
ns = Constraint._as_non_empty_list(ns, "ns")
344+
es = Constraint._as_non_empty_list(es, "es")
345+
Constraint._assert_same_length("node_from", node_from, "node_to", node_to)
346+
Constraint._assert_same_length("node_from", node_from, "es", es)
245347
return Constraint._from_global_constraint("reachable", Constraint.CTYPE_REACHABLE, node_from, node_to, r, ns, es)

src/pymzm/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,4 @@ class PymzmInvalidSearchAnnotation(PymzmException):
6666
def __init__(self):
6767
pass
6868
def __str__(self):
69-
return f"PymzmInvalidSearchAnnotation."
69+
return "PymzmInvalidSearchAnnotation."

src/pymzm/expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _func(cls, func_symbol: str, exprs):
386386
exprs2.append(expr.name)
387387
else:
388388
exprs2.append(expr)
389-
out = f", ".join(str(a) for a in exprs2)
389+
out = ", ".join(str(a) for a in exprs2)
390390
return cls(f"{func_symbol}({out})")
391391

392392
@staticmethod

src/pymzm/model.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ def __init__(self, search_type: str, variables: List["Variable"], varchoice: str
131131
self.variables = variables
132132

133133
self.varchoice = varchoice
134-
if (not self.varchoice in AnnotationVariableChoice.VARCHOICES):
134+
if (self.varchoice not in AnnotationVariableChoice.VARCHOICES):
135135
raise PymzmInvalidVarchoiceAnnotation("varchoice")
136136

137137
self.valchoice = valchoice
138-
if (not self.valchoice in AnnotationValueChoice.VALCHOICES):
138+
if (self.valchoice not in AnnotationValueChoice.VALCHOICES):
139139
raise PymzmInvalidValchoiceAnnotation("valchoice")
140140

141141
def __str__(self):
@@ -328,25 +328,27 @@ def add_enum(self, type_name: str, members: List[str]) -> EnumDomain:
328328
self.enums[type_name] = enum_domain
329329
return enum_domain
330330

331-
def add_variable(self, name: str, vtype: int=Variable.VTYPE_INTEGER, val_min: int=None, val_max: int=None, domain: set=None, annotations=None):
331+
def add_variable(self, name: str, vtype: int=Variable.VTYPE_INTEGER, val_min: int=None, val_max: int=None, domain=None, annotations=None):
332332
variable = Variable(name, vtype, val_min, val_max, domain, annotations=annotations)
333333
self.variables.append(variable)
334334
return variable
335335

336-
def add_variables(self, name: str, indices: List[Tuple[int]], vtype: int=Variable.VTYPE_INTEGER, val_min: int=None, val_max: int=None, domains: set=None, annotations=None) -> ValueDict:
336+
def add_variables(self, name: str, indices: List[Tuple[int]], vtype: int=Variable.VTYPE_INTEGER, val_min: int=None, val_max: int=None, domains=None, annotations=None) -> ValueDict:
337337

338338
# Domain
339339
if (domains is None):
340-
domains = {}
340+
domain_map = {}
341341
elif (type(domains) is set):
342-
domains = {idx: domains for idx in indices}
342+
domain_map = {idx: domains for idx in indices}
343343
elif (type(domains) is list):
344344
assert len(domains) == len(indices)
345-
domains = {idx: domains[i] for i, idx in enumerate(indices)}
345+
domain_map = {idx: domains[i] for i, idx in enumerate(indices)}
346346
elif (type(domains) is dict):
347347
assert len(domains) == len(indices)
348348
assert set(domains.keys()) == set(indices)
349-
domains = {idx: domains[idx] for idx in indices}
349+
domain_map = {idx: domains[idx] for idx in indices}
350+
else:
351+
raise TypeError("domains must be None, set, list, or dict")
350352

351353
# Annotations
352354
if (annotations is None):
@@ -368,7 +370,7 @@ def add_variables(self, name: str, indices: List[Tuple[int]], vtype: int=Variabl
368370
vtype,
369371
val_min,
370372
val_max,
371-
domains.get(idx, None),
373+
domain_map.get(idx, None),
372374
annotations=annotations.get(idx, None),
373375
)
374376
self.variables.append(variable)

src/pymzm/result.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from enum import Enum
22
from typing import Any
33

4-
import minizinc
5-
64

75
_STATUS_MAP = {
86
"SATISFIED": "SAT",

0 commit comments

Comments
 (0)