diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index ba5f36c67..6693988ba 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -93,7 +93,7 @@ import copy import warnings from dataclasses import dataclass -from typing import Any, Final, Optional, TypeAlias, TypeVar, Union, Sequence, Iterable +from typing import Any, Callable, Final, Optional, TypeAlias, TypeVar, Union, Sequence, Iterable from frozendict import frozendict import numpy as np import cpmpy as cp @@ -168,26 +168,28 @@ def update_args(self, args: Iterable[Any], has_subexpr: Optional[bool] = None) - def set_description(self, txt: str, override_print: bool = True, full_print: bool = False) -> None: self._description = Description(txt, override_print, full_print) + def _to_string(self, str_func: Callable[[Any], str]) -> str: + strargs = [] + for arg in self.args: + if isinstance(arg, np.ndarray): + # flatten + strarg = ", ".join(map(str_func, arg.flat)) # with space to match list printing + strargs.append(f"[{strarg}]") + else: + strargs.append(str_func(arg)) + return "{}({})".format(self.name, ",".join(strargs)) + def __str__(self) -> str: d = self._description if d is None or not d.override_print: - return self.__repr__() + return self._to_string(str_func=str) out = d.text if d.full_print: out += " -- " + self.__repr__() return out - def __repr__(self) -> str: - strargs = [] - for arg in self.args: - if isinstance(arg, np.ndarray): - # flatten - strarg = ",".join(map(str, arg.flat)) - strargs.append(f"[{strarg}]") - else: - strargs.append(f"{arg}") - return "{}({})".format(self.name, ",".join(strargs)) + return self._to_string(str_func=repr) def __hash__(self) -> int: return hash(self.__repr__()) @@ -612,11 +614,11 @@ def __init__(self, name: str, left: ExprLike, right: ExprLike) -> None: assert (name in Comparison.allowed), f"Symbol {name} not allowed" super().__init__(name, (left, right)) - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: if all(isinstance(x, Expression) for x in self.args): - return "({}) {} ({})".format(self.args[0], self.name, self.args[1]) + return "({}) {} ({})".format(str_func(self.args[0]), self.name, str_func(self.args[1])) # if not: prettier printing without braces - return "{} {} {}".format(self.args[0], self.name, self.args[1]) + return "{} {} {}".format(str_func(self.args[0]), self.name, str_func(self.args[1])) def __bool__(self) -> bool: # will be called when comparing elements in a container, but always with `==` @@ -729,26 +731,26 @@ def is_bool(self) -> bool: """ return Operator.allowed[self.name][1] - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: # special cases if self.name == '-': # unary - - return "-({})".format(self.args[0]) + return "-({})".format(str_func(self.args[0])) # weighted sum if self.name == 'wsum': - return f"sum({self.args[0]} * {self.args[1]})" + return f"sum({self.args[0]} * {str_func(self.args[1])})" if len(self.args) == 1: - return "{}({})".format(self.name, self.args[0]) # tuple of size 1 omitted in print + return "{}({})".format(self.name, str_func(self.args[0])) # tuple of size 1 omitted in print elif len(self.args) == 2: # infix printing of two arguments printname = Operator.printmap.get(self.name, self.name) # default to self.name if not in printmap arg0, arg1 = self.args - str_arg0 = f"({arg0})" if isinstance(arg0, Expression) else str(arg0) - str_arg1 = f"({arg1})" if isinstance(arg1, Expression) else str(arg1) + str_arg0 = f"({str_func(arg0)})" if isinstance(arg0, Expression) else str(arg0) + str_arg1 = f"({str_func(arg1)})" if isinstance(arg1, Expression) else str(arg1) return f"{str_arg0} {printname} {str_arg1}" else: # n-ary - return "{}{}".format(self.name, self.args) # args is a tuple, will be in () + return "{}{}".format(self.name, str_func(self.args)) # args is a tuple, will be in () def value(self) -> Optional[int]: """ diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index aed6c78d4..5f18c6bf9 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -133,7 +133,7 @@ def my_circuit_decomp(self): """ import warnings -from typing import cast, Literal, Optional, Iterable, Any, TYPE_CHECKING +from typing import Callable, cast, Literal, Optional, Iterable, Any, TYPE_CHECKING import numpy as np import cpmpy as cp @@ -871,9 +871,9 @@ def decompose(self) -> tuple[list[Expression], list[Expression]]: condition = cp.BoolVal(condition) # ensure it is a CPMpy expression return [condition.implies(if_true), (~condition).implies(if_false)], [] - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: condition, if_true, if_false = self.args - return "If {} Then {} Else {}".format(condition, if_true, if_false) + return "If {} Then {} Else {}".format(str_func(condition), str_func(if_true), str_func(if_false)) def negate(self) -> Expression: return IfThenElse(self.args[0], self.args[2], self.args[1]) @@ -929,9 +929,9 @@ def value(self) -> Optional[bool]: return None return bool(np.any(arr == exprval)) - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: expr, arr = self.args - return "{} in {}".format(expr, arr) + return "{} in {}".format(str_func(expr), str_func(arr)) def negate(self) -> Expression: expr, arr = self.args @@ -1009,10 +1009,10 @@ def value(self) -> Optional[bool]: return None return sum(arrvals) % 2 == 1 - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: if len(self.args) == 2: - return "{} xor {}".format(*self.args) - return "xor({})".format(self.args) + return "{} xor {}".format(str_func(self.args[0]), str_func(self.args[1])) + return "xor({})".format(str_func(self.args)) def negate(self) -> Expression: # negate one of the arguments, ideally a variable diff --git a/cpmpy/expressions/globalfunctions.py b/cpmpy/expressions/globalfunctions.py index 567cca85f..15638ca6d 100644 --- a/cpmpy/expressions/globalfunctions.py +++ b/cpmpy/expressions/globalfunctions.py @@ -73,13 +73,13 @@ def decompose(self): """ import warnings # for deprecation warning -from typing import Optional, Iterable +from typing import Any, Callable, Optional, Iterable import numpy as np import cpmpy as cp from ..exceptions import CPMpyException, IncompleteFunctionError, TypeError from .core import Expression, Operator, ExprLike, ListLike -from .variables import intvar, NDVarArray, _NumVarImpl, BoolVal +from .variables import cpm_array, intvar, NDVarArray, _NumVarImpl, BoolVal from .utils import argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds, argvals, implies, argvals_intexpr, get_bounds_intexpr, npint2int @@ -379,13 +379,13 @@ def update_args(self, args): super().update_args((x, y)) self.is_lhs_num = is_lhs_num - def __repr__(self): + def _to_string(self, str_func: Callable[[Any], str]) -> str: x, y = self.args if self.is_lhs_num: - return "{} * ({})".format(x, y) + return "{} * ({})".format(str_func(x), str_func(y)) - return "({}) * ({})".format(x, y) + return "({}) * ({})".format(str_func(x), str_func(y)) def __neg__(self): """-(c*x) -> (-c)*x when constant c is first (.is_lhs_num).""" @@ -481,14 +481,14 @@ def __init__(self, x: ExprLike, y: ExprLike): """ super().__init__("div", (x, y)) - def __repr__(self): + def _to_string(self, str_func: Callable[[Any], str]) -> str: """ Returns: str: String representation of integer division as 'x div y' """ x,y = self.args - return "{} div {}".format(f"({x})" if isinstance(x, Expression) else x, - f"({y})" if isinstance(y, Expression) else y) + return "{} div {}".format(f"({str_func(x)})" if isinstance(x, Expression) else str_func(x), + f"({str_func(y)})" if isinstance(y, Expression) else str_func(y)) def decompose(self): """ @@ -576,14 +576,14 @@ def __init__(self, x: ExprLike, y: ExprLike): """ super().__init__("mod", (x, y)) - def __repr__(self): + def _to_string(self, str_func: Callable[[Any], str]) -> str: """ Returns: str: String representation with 'mod' as notation """ x,y = self.args - return "{} mod {}".format(f"({x})" if isinstance(x, Expression) else x, - f"({y})" if isinstance(y, Expression) else y) + return "{} mod {}".format(f"({str_func(x)})" if isinstance(x, Expression) else str_func(x), + f"({str_func(y)})" if isinstance(y, Expression) else str_func(y)) def decompose(self): """ @@ -814,14 +814,19 @@ def get_bounds(self) -> tuple[int, int]: bnds = [get_bounds(x) for x in arr] return min(lb for lb,ub in bnds), max(ub for lb,ub in bnds) - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: """ Custom string representation of the Element global function in 'Arr[Idx]' format. Returns: str: String representation of the Element global function. """ - return f"{self.args[0]}[{self.args[1]}]" + arr, idx = self.args + if isinstance(arr, np.ndarray): + str_arr = str_func(arr.tolist()) # overkill, also converts np int to Python int + else: + str_arr = str_func(arr) + return f"{str_arr}[{str_func(idx)}]" def element(arg_list): """ @@ -853,6 +858,7 @@ def __init__(self, arr: ListLike[ExprLike], val: ExprLike): raise TypeError(f"Count(arr, val) takes an array of expressions as first argument, not: {arr}") if is_any_list(val): raise TypeError(f"Count(arr, val) takes a numeric expression as second argument, not a list: {val}") + arr = cpm_array(arr) super().__init__("count", (arr, val)) def decompose(self) -> tuple[Expression, list[Expression]]: diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index 80c93b36e..5ab4bc33f 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -61,7 +61,7 @@ from collections.abc import Iterable import warnings # for deprecation warning from functools import reduce -from typing import Any, Literal, Optional, overload +from typing import Any, Callable, Literal, Optional, overload import numpy as np import cpmpy as cp # to avoid circular import @@ -354,7 +354,7 @@ def clear(self) -> None: """ self._value = None - def __repr__(self) -> str: + def _to_string(self, str_func:Callable[[Any], str]) -> str: return self.name # for sets/dicts. Because names are unique, so is the str repr @@ -454,7 +454,7 @@ def clear(self) -> None: """ self._bv.clear() - def __repr__(self) -> str: + def _to_string(self, str_func: Callable[[Any], str]) -> str: return "~{}".format(self._bv.name) def __invert__(self) -> Expression: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index c16cb226a..dfc8e55f7 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -449,10 +449,10 @@ def test_description(self): a,b = cp.boolvar(name="a"), cp.boolvar(name="b") cons = a | b - cons.set_description("either a or b should be true, but not both") + cons.set_description("either a or b should be true") assert repr(cons) == "(a) or (b)" - assert str(cons) == "either a or b should be true, but not both" + assert str(cons) == "either a or b should be true" # ensure nothing goes wrong due to calling __str__ on a constraint with a custom description for solver,cls in cp.SolverLookup.base_solvers(): @@ -465,18 +465,58 @@ def test_description(self): ## test extra attributes of set_description cons = a | b - cons.set_description("either a or b should be true, but not both", + cons.set_description("either a or b should be true", override_print=False) assert repr(cons) == "(a) or (b)" assert str(cons) == "(a) or (b)" cons = a | b - cons.set_description("either a or b should be true, but not both", + cons.set_description("either a or b should be true", full_print=True) assert repr(cons) == "(a) or (b)" - assert str(cons) == "either a or b should be true, but not both -- (a) or (b)" + assert str(cons) == "either a or b should be true -- (a) or (b)" + + def test_nested_description(self): + + a = cp.boolvar(name="a") + b = cp.boolvar(name="b") + c = cp.boolvar(name="c") + cons = a | b + cons.set_description("either a or b should be true") + assert repr(cons) == "(a) or (b)" + assert str(cons) == "either a or b should be true" + + cons2 = c.implies(cons) + + assert repr(cons2) == "(c) -> ((a) or (b))" # don't use description of nested expression + assert str(cons2) == "(c) -> (either a or b should be true)" # use description of nested expression + + + def test_expr_list_vs_ndarray(self): + + x = cp.intvar(0,10,shape=3, name=("a","b","c")) + assert isinstance(x, NDVarArray) + + cons = cp.AllDifferent(x) + assert repr(cons) == "alldifferent(a,b,c)" + assert str(cons) == "alldifferent(a,b,c)" + + cons = cp.AllDifferent(list(x)) + assert repr(cons) == "alldifferent(a,b,c)" + assert str(cons) == "alldifferent(a,b,c)" + + cons = cp.Count(x, 3) >= 1 + assert repr(cons) == "count([a, b, c],3) >= 1" + assert str(cons) == "count([a, b, c],3) >= 1" + + cons = cp.Count(list(x), 3) >= 1 + assert repr(cons) == "count([a, b, c],3) >= 1" + assert str(cons) == "count([a, b, c],3) >= 1" + + assert str(cp.Count(x,2)) == str(cp.Count(list(x),2)) + assert repr(cp.Count(x,2)) == repr(cp.Count(list(x),2)) # should be the same, functionally equivalent def test_dtype(self): diff --git a/tests/test_flatten.py b/tests/test_flatten.py index 3e19d9b78..7dd1b6c02 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -237,7 +237,7 @@ def test_get_or_make_var__num(self): v, cons = get_or_make_var(cp.cpm_array([1, 2, 3])[a]) assert str(v) == "IV13" - assert {str(c) for c in cons} == {"([1 2 3][IV0]) == (IV13)"} + assert {str(c) for c in cons} == {"([1, 2, 3][IV0]) == (IV13)"} v, cons = get_or_make_var(cp.cpm_array([b + c, 2, 3])[a]) assert str(v) == "IV15" @@ -269,8 +269,8 @@ def test_objective(self): assert str(flatten_objective( 2*a-3*(b - c*2) )) == '(sum([2, -3, 6] * [IV0, IV1, IV2]), [])' cp.intvar(0,2) # increase counter assert str(flatten_objective( a//b+c )) == f"((IV6) + ({str(c)}), [(({str(a)}) div ({str(b)})) == (IV6)])" - assert str(flatten_objective( cp.cpm_array([1,2,3])[a] )) == "(IV7, [([1 2 3][IV0]) == (IV7)])" - assert str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )) == "((IV8) + (IV1), [([1 2 3][IV0]) == (IV8)])" + assert str(flatten_objective( cp.cpm_array([1,2,3])[a] )) == "(IV7, [([1, 2, 3][IV0]) == (IV7)])" + assert str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )) == "((IV8) + (IV1), [([1, 2, 3][IV0]) == (IV8)])" def test_constraint(self): @@ -304,10 +304,10 @@ def test_constraint(self): #self.assertEqual( str(flatten_constraint( c != a + b )), "[((IV0) + (IV1)) != (IV2)]" ) # TODO, make it do the swap (again) assert str(flatten_constraint( ((a > 5) == (b < 3)) )) == "[(IV0 > 5) == (BV8), (IV1 < 3) == (BV8)]" - assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] == b )) == "[([1 2 3][IV0]) == (IV1)]" - assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] > b )) == "[([1 2 3][IV0]) > (IV1)]" + assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] == b )) == "[([1, 2, 3][IV0]) == (IV1)]" + assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] > b )) == "[([1, 2, 3][IV0]) > (IV1)]" cp.intvar(0,2, 4) # increase counter - assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] <= b )) == "[([1 2 3][IV0]) <= (IV1)]" + assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] <= b )) == "[([1, 2, 3][IV0]) <= (IV1)]" assert str(flatten_constraint( cp.AllDifferent([a+b,b+c,c+3]) )) == "[alldifferent(IV9,IV10,IV11), ((IV0) + (IV1)) == (IV9), ((IV1) + (IV2)) == (IV10), ((IV2) + 3) == (IV11)]" # issue #27 diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 505199b9c..15269adee 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -753,9 +753,9 @@ def test_element(self): assert arr[a.value(), b.value()] == 1 # test optimization where 1 dim is index cons = iv[2, idx] == 8 - assert str(cons) == "[iv[2,0] iv[2,1] iv[2,2]][idx] == 8" + assert str(cons) == "[iv[2,0], iv[2,1], iv[2,2]][idx] == 8" cons = iv[idx, 2] == 8 - assert str(cons) == "[iv[0,2] iv[1,2] iv[2,2]][idx] == 8" + assert str(cons) == "[iv[0,2], iv[1,2], iv[2,2]][idx] == 8" def test_multid_1expr(self): @@ -763,13 +763,13 @@ def test_multid_1expr(self): a,b = cp.intvar(0,2, shape=2, name=tuple("ab")) # idx is always safe expr = x[a,1,3] - assert str(expr) == "[x[0,1,3] x[1,1,3] x[2,1,3]][a]" + assert str(expr) == "[x[0,1,3], x[1,1,3], x[2,1,3]][a]" expr = x[1,a,3] - assert str(expr) == "[x[1,0,3] x[1,1,3] x[1,2,3] x[1,3,3]][a]" + assert str(expr) == "[x[1,0,3], x[1,1,3], x[1,2,3], x[1,3,3]][a]" expr = x[1,2,a] - assert str(expr) == "[x[1,2,0] x[1,2,1] x[1,2,2] x[1,2,3] x[1,2,4]][a]" + assert str(expr) == "[x[1,2,0], x[1,2,1], x[1,2,2], x[1,2,3], x[1,2,4]][a]" def test_element_onearg(self): diff --git a/tests/test_transf_reif.py b/tests/test_transf_reif.py index 6bc206d67..682ab5ed8 100644 --- a/tests/test_transf_reif.py +++ b/tests/test_transf_reif.py @@ -93,7 +93,7 @@ def test_reif_rewrite(self): assert f((bvs[0].implies(bvs[1])).implies(rv)) == "[(~rv) -> (bvs[0]), (~rv) -> (~bvs[1])]" pytest.raises(ValueError, lambda : f(rv == cp.AllDifferent(ivs))) assert fd([rv.implies(cp.AllDifferent(ivs))]) == "[(rv) -> ((ivs[0]) != (ivs[1])), (rv) -> ((ivs[0]) != (ivs[2])), (rv) -> ((ivs[1]) != (ivs[2]))]" - assert f(rv == (arr[cp.intvar(0, 2)] != 1)) == "[([0 1 2][IV0]) == (IV1), (IV1 != 1) == (rv)]" + assert f(rv == (arr[cp.intvar(0, 2)] != 1)) == "[([0, 1, 2][IV0]) == (IV1), (IV1 != 1) == (rv)]" assert f(rv == (cp.max(ivs) > 5)) == "[(max(ivs[0],ivs[1],ivs[2])) == (IV2), (IV2 > 5) == (rv)]" assert f(rv.implies(cp.min(ivs) != 0)) == "[(min(ivs[0],ivs[1],ivs[2])) == (IV3), (rv) -> (IV3 != 0)]" assert f((cp.min(ivs) != 0).implies(rv)) == "[(min(ivs[0],ivs[1],ivs[2])) == (IV4), (IV4 != 0) -> (rv)]"