Skip to content

Commit a2cd4b3

Browse files
committed
Add array/set comprehension support
Implement array and set comprehensions in Expression: add helpers to validate and render generator clauses (_generator_clause_to_mz), convert comprehension expressions (_comprehension_expr_to_mz), and public factories array_comprehension and set_comprehension. These methods validate inputs, accept callables that are invoked with generator variables, support optional generator predicates (rendered with "where"), and raise existing Pymzm exceptions on invalid input. Tests updated to cover rendering, use in model constraints, and input validation for comprehensions.
1 parent 495e85e commit a2cd4b3

2 files changed

Lines changed: 102 additions & 0 deletions

File tree

src/pymzm/expression.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,69 @@ def _predicate_to_mz(var_name: str, predicate) -> str:
219219

220220
raise PymzmValueIsNotCondition("predicate", predicate)
221221

222+
@staticmethod
223+
def _generator_clause_to_mz(generator, arg_name: str="generators") -> str:
224+
if (not isinstance(generator, (tuple, list)) or len(generator) not in (2, 3)):
225+
raise PymzmValueIsNotExpression(arg_name, generator)
226+
227+
var_name = generator[0]
228+
domain = generator[1]
229+
predicate = generator[2] if (len(generator) == 3) else None
230+
231+
if (not isinstance(var_name, str) or not var_name.strip()):
232+
raise PymzmValueIsNotExpression("var_name", var_name)
233+
234+
domain_mz = Expression._domain_to_mz(domain)
235+
clause = f"{var_name} in {domain_mz}"
236+
if (predicate is not None):
237+
predicate_mz = Expression._predicate_to_mz(var_name, predicate)
238+
clause += f" where {predicate_mz}"
239+
240+
return clause
241+
242+
@staticmethod
243+
def _comprehension_expr_to_mz(expr, generator_var_names):
244+
if (isinstance(expr, Callable)):
245+
vars_expr = [Expression(name) for name in generator_var_names]
246+
if (len(vars_expr) == 1):
247+
expr = expr(vars_expr[0])
248+
else:
249+
expr = expr(*vars_expr)
250+
251+
if (isinstance(expr, bool)):
252+
return "true" if expr else "false"
253+
254+
if (isinstance(expr, (Expression, int, float))):
255+
return str(expr)
256+
257+
raise PymzmValueIsNotExpression("expr", expr)
258+
259+
@staticmethod
260+
def array_comprehension(expr, generators):
261+
if (not isinstance(generators, Iterable) or isinstance(generators, (str, bytes))):
262+
raise PymzmValueIsNotExpression("generators", generators)
263+
264+
generators = list(generators)
265+
if (not len(generators)):
266+
raise PymzmNoValues("generators")
267+
268+
clauses = [Expression._generator_clause_to_mz(g) for g in generators]
269+
expr_mz = Expression._comprehension_expr_to_mz(expr, [g[0] for g in generators])
270+
return Expression(f"[{expr_mz} | {', '.join(clauses)}]")
271+
272+
@staticmethod
273+
def set_comprehension(expr, generators):
274+
if (not isinstance(generators, Iterable) or isinstance(generators, (str, bytes))):
275+
raise PymzmValueIsNotExpression("generators", generators)
276+
277+
generators = list(generators)
278+
if (not len(generators)):
279+
raise PymzmNoValues("generators")
280+
281+
clauses = [Expression._generator_clause_to_mz(g) for g in generators]
282+
expr_mz = Expression._comprehension_expr_to_mz(expr, [g[0] for g in generators])
283+
return Expression(f"{{{expr_mz} | {', '.join(clauses)}}}")
284+
222285
@staticmethod
223286
def predicate(name: str, *args) -> "ExpressionBool":
224287
if (not isinstance(name, str) or not name.strip()):

tests/test_expression_validation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,45 @@ def test_expression_render_is_valid_mzn_when_embedded_in_model(self):
128128

129129
assert_valid_mzn(self, model.model_mzn_str)
130130

131+
def test_array_and_set_comprehension_render_with_filters(self):
132+
arr = pymzm.Expression.array_comprehension(
133+
lambda i: i + 1,
134+
[("i", range(1, 4), lambda i: i >= 2)],
135+
)
136+
s = pymzm.Expression.set_comprehension(
137+
lambda i, j: i + j,
138+
[("i", "1..2"), ("j", [1, 2, 3], lambda j: j >= 2)],
139+
)
140+
141+
self.assertEqual(str(arr), "[(i + 1) | i in 1..3 where (i >= 2)]")
142+
self.assertEqual(str(s), "{(i + j) | i in 1..2, j in {1, 2, 3} where (j >= 2)}")
143+
144+
def test_comprehensions_can_be_used_in_model_constraints(self):
145+
model = pymzm.Model()
146+
x = model.add_variable("x", val_min=0, val_max=10)
147+
148+
arr = pymzm.Expression.array_comprehension(lambda i: i + x, [("i", range(1, 3))])
149+
s = pymzm.Expression.set_comprehension(lambda i: i, [("i", [1, 2, 3], lambda i: i >= 2)])
150+
151+
model.add_constraint(x >= 0)
152+
model.add_constraint(pymzm.Expression._func("sum", [arr]) >= 0)
153+
model.add_constraint(pymzm.Expression._func("card", [s]) >= 1)
154+
model.set_solve_criteria(pymzm.SOLVE_SATISFY)
155+
model.generate()
156+
157+
self.assertIn("sum([", model.model_mzn_str)
158+
self.assertIn("{i | i in {1, 2, 3} where (i >= 2)}", model.model_mzn_str)
159+
assert_valid_mzn(self, model.model_mzn_str)
160+
161+
def test_comprehensions_validate_inputs(self):
162+
self.assertRaises(pymzm.PymzmNoValues, pymzm.Expression.array_comprehension, 1, [])
163+
self.assertRaises(pymzm.PymzmValueIsNotExpression, pymzm.Expression.array_comprehension, object(), [("i", [1])])
164+
self.assertRaises(pymzm.PymzmValueIsNotExpression, pymzm.Expression.array_comprehension, 1, ["bad"])
165+
self.assertRaises(pymzm.PymzmValueIsNotCondition, pymzm.Expression.array_comprehension, 1, [("i", [1], 7)])
166+
167+
self.assertRaises(pymzm.PymzmNoValues, pymzm.Expression.set_comprehension, 1, [])
168+
self.assertRaises(pymzm.PymzmValueIsNotExpression, pymzm.Expression.set_comprehension, object(), [("i", [1])])
169+
131170

132171
if __name__ == "__main__":
133172
unittest.main()

0 commit comments

Comments
 (0)