Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
import copy
import warnings
from dataclasses import dataclass
from typing import Any, Final, Optional, TypeAlias, TypeVar, Union, Sequence, Iterable
from typing import Any, Callable, Final, Optional, TypeAlias, TypeVar, Union, Sequence, Iterable
from frozendict import frozendict
import numpy as np
import cpmpy as cp
Expand Down Expand Up @@ -168,26 +168,28 @@ def update_args(self, args: Iterable[Any], has_subexpr: Optional[bool] = None) -
def set_description(self, txt: str, override_print: bool = True, full_print: bool = False) -> None:
self._description = Description(txt, override_print, full_print)

def _to_string(self, str_func: Callable[[Any], str]) -> str:
strargs = []
for arg in self.args:
if isinstance(arg, np.ndarray):
# flatten
strarg = ", ".join(map(str_func, arg.flat)) # with space to match list printing
strargs.append(f"[{strarg}]")
else:
strargs.append(str_func(arg))
return "{}({})".format(self.name, ",".join(strargs))

def __str__(self) -> str:
d = self._description
if d is None or not d.override_print:
return self.__repr__()
return self._to_string(str_func=str)
out = d.text
if d.full_print:
out += " -- " + self.__repr__()
return out


def __repr__(self) -> str:
strargs = []
for arg in self.args:
if isinstance(arg, np.ndarray):
# flatten
strarg = ",".join(map(str, arg.flat))
strargs.append(f"[{strarg}]")
else:
strargs.append(f"{arg}")
return "{}({})".format(self.name, ",".join(strargs))
return self._to_string(str_func=repr)

def __hash__(self) -> int:
return hash(self.__repr__())
Expand Down Expand Up @@ -612,11 +614,11 @@ def __init__(self, name: str, left: ExprLike, right: ExprLike) -> None:
assert (name in Comparison.allowed), f"Symbol {name} not allowed"
super().__init__(name, (left, right))

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
if all(isinstance(x, Expression) for x in self.args):
return "({}) {} ({})".format(self.args[0], self.name, self.args[1])
return "({}) {} ({})".format(str_func(self.args[0]), self.name, str_func(self.args[1]))
# if not: prettier printing without braces
return "{} {} {}".format(self.args[0], self.name, self.args[1])
return "{} {} {}".format(str_func(self.args[0]), self.name, str_func(self.args[1]))

def __bool__(self) -> bool:
# will be called when comparing elements in a container, but always with `==`
Expand Down Expand Up @@ -729,26 +731,26 @@ def is_bool(self) -> bool:
"""
return Operator.allowed[self.name][1]

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:

# special cases
if self.name == '-': # unary -
return "-({})".format(self.args[0])
return "-({})".format(str_func(self.args[0]))

# weighted sum
if self.name == 'wsum':
return f"sum({self.args[0]} * {self.args[1]})"
return f"sum({self.args[0]} * {str_func(self.args[1])})"

if len(self.args) == 1:
return "{}({})".format(self.name, self.args[0]) # tuple of size 1 omitted in print
return "{}({})".format(self.name, str_func(self.args[0])) # tuple of size 1 omitted in print
elif len(self.args) == 2: # infix printing of two arguments
printname = Operator.printmap.get(self.name, self.name) # default to self.name if not in printmap
arg0, arg1 = self.args
str_arg0 = f"({arg0})" if isinstance(arg0, Expression) else str(arg0)
str_arg1 = f"({arg1})" if isinstance(arg1, Expression) else str(arg1)
str_arg0 = f"({str_func(arg0)})" if isinstance(arg0, Expression) else str(arg0)
str_arg1 = f"({str_func(arg1)})" if isinstance(arg1, Expression) else str(arg1)
return f"{str_arg0} {printname} {str_arg1}"
else: # n-ary
return "{}{}".format(self.name, self.args) # args is a tuple, will be in ()
return "{}{}".format(self.name, str_func(self.args)) # args is a tuple, will be in ()

def value(self) -> Optional[int]:
"""
Expand Down
16 changes: 8 additions & 8 deletions cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def my_circuit_decomp(self):

"""
import warnings
from typing import cast, Literal, Optional, Iterable, Any, TYPE_CHECKING
from typing import Callable, cast, Literal, Optional, Iterable, Any, TYPE_CHECKING
import numpy as np

import cpmpy as cp
Expand Down Expand Up @@ -871,9 +871,9 @@ def decompose(self) -> tuple[list[Expression], list[Expression]]:
condition = cp.BoolVal(condition) # ensure it is a CPMpy expression
return [condition.implies(if_true), (~condition).implies(if_false)], []

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
condition, if_true, if_false = self.args
return "If {} Then {} Else {}".format(condition, if_true, if_false)
return "If {} Then {} Else {}".format(str_func(condition), str_func(if_true), str_func(if_false))

def negate(self) -> Expression:
return IfThenElse(self.args[0], self.args[2], self.args[1])
Expand Down Expand Up @@ -929,9 +929,9 @@ def value(self) -> Optional[bool]:
return None
return bool(np.any(arr == exprval))

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
expr, arr = self.args
return "{} in {}".format(expr, arr)
return "{} in {}".format(str_func(expr), str_func(arr))

def negate(self) -> Expression:
expr, arr = self.args
Expand Down Expand Up @@ -1009,10 +1009,10 @@ def value(self) -> Optional[bool]:
return None
return sum(arrvals) % 2 == 1

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
if len(self.args) == 2:
return "{} xor {}".format(*self.args)
return "xor({})".format(self.args)
return "{} xor {}".format(str_func(self.args[0]), str_func(self.args[1]))
return "xor({})".format(str_func(self.args))

def negate(self) -> Expression:
# negate one of the arguments, ideally a variable
Expand Down
32 changes: 19 additions & 13 deletions cpmpy/expressions/globalfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def decompose(self):

"""
import warnings # for deprecation warning
from typing import Optional, Iterable
from typing import Any, Callable, Optional, Iterable
import numpy as np
import cpmpy as cp

from ..exceptions import CPMpyException, IncompleteFunctionError, TypeError
from .core import Expression, Operator, ExprLike, ListLike
from .variables import intvar, NDVarArray, _NumVarImpl, BoolVal
from .variables import cpm_array, intvar, NDVarArray, _NumVarImpl, BoolVal
from .utils import argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds, argvals, implies, argvals_intexpr, get_bounds_intexpr, npint2int


Expand Down Expand Up @@ -379,13 +379,13 @@ def update_args(self, args):
super().update_args((x, y))
self.is_lhs_num = is_lhs_num

def __repr__(self):
def _to_string(self, str_func: Callable[[Any], str]) -> str:
x, y = self.args

if self.is_lhs_num:
return "{} * ({})".format(x, y)
return "{} * ({})".format(str_func(x), str_func(y))

return "({}) * ({})".format(x, y)
return "({}) * ({})".format(str_func(x), str_func(y))

def __neg__(self):
"""-(c*x) -> (-c)*x when constant c is first (.is_lhs_num)."""
Expand Down Expand Up @@ -481,14 +481,14 @@ def __init__(self, x: ExprLike, y: ExprLike):
"""
super().__init__("div", (x, y))

def __repr__(self):
def _to_string(self, str_func: Callable[[Any], str]) -> str:
"""
Returns:
str: String representation of integer division as 'x div y'
"""
x,y = self.args
return "{} div {}".format(f"({x})" if isinstance(x, Expression) else x,
f"({y})" if isinstance(y, Expression) else y)
return "{} div {}".format(f"({str_func(x)})" if isinstance(x, Expression) else str_func(x),
f"({str_func(y)})" if isinstance(y, Expression) else str_func(y))

def decompose(self):
"""
Expand Down Expand Up @@ -576,14 +576,14 @@ def __init__(self, x: ExprLike, y: ExprLike):
"""
super().__init__("mod", (x, y))

def __repr__(self):
def _to_string(self, str_func: Callable[[Any], str]) -> str:
"""
Returns:
str: String representation with 'mod' as notation
"""
x,y = self.args
return "{} mod {}".format(f"({x})" if isinstance(x, Expression) else x,
f"({y})" if isinstance(y, Expression) else y)
return "{} mod {}".format(f"({str_func(x)})" if isinstance(x, Expression) else str_func(x),
f"({str_func(y)})" if isinstance(y, Expression) else str_func(y))

def decompose(self):
"""
Expand Down Expand Up @@ -814,14 +814,19 @@ def get_bounds(self) -> tuple[int, int]:
bnds = [get_bounds(x) for x in arr]
return min(lb for lb,ub in bnds), max(ub for lb,ub in bnds)

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
"""
Custom string representation of the Element global function in 'Arr[Idx]' format.

Returns:
str: String representation of the Element global function.
"""
return f"{self.args[0]}[{self.args[1]}]"
arr, idx = self.args
if isinstance(arr, np.ndarray):
str_arr = str_func(arr.tolist()) # overkill, also converts np int to Python int
else:
str_arr = str_func(arr)
return f"{str_arr}[{str_func(idx)}]"

def element(arg_list):
"""
Expand Down Expand Up @@ -853,6 +858,7 @@ def __init__(self, arr: ListLike[ExprLike], val: ExprLike):
raise TypeError(f"Count(arr, val) takes an array of expressions as first argument, not: {arr}")
if is_any_list(val):
raise TypeError(f"Count(arr, val) takes a numeric expression as second argument, not a list: {val}")
arr = cpm_array(arr)
super().__init__("count", (arr, val))

def decompose(self) -> tuple[Expression, list[Expression]]:
Expand Down
6 changes: 3 additions & 3 deletions cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from collections.abc import Iterable
import warnings # for deprecation warning
from functools import reduce
from typing import Any, Literal, Optional, overload
from typing import Any, Callable, Literal, Optional, overload

import numpy as np
import cpmpy as cp # to avoid circular import
Expand Down Expand Up @@ -354,7 +354,7 @@ def clear(self) -> None:
"""
self._value = None

def __repr__(self) -> str:
def _to_string(self, str_func:Callable[[Any], str]) -> str:
return self.name

# for sets/dicts. Because names are unique, so is the str repr
Expand Down Expand Up @@ -454,7 +454,7 @@ def clear(self) -> None:
"""
self._bv.clear()

def __repr__(self) -> str:
def _to_string(self, str_func: Callable[[Any], str]) -> str:
return "~{}".format(self._bv.name)

def __invert__(self) -> Expression:
Expand Down
50 changes: 45 additions & 5 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,10 @@ def test_description(self):

a,b = cp.boolvar(name="a"), cp.boolvar(name="b")
cons = a | b
cons.set_description("either a or b should be true, but not both")
cons.set_description("either a or b should be true")

assert repr(cons) == "(a) or (b)"
assert str(cons) == "either a or b should be true, but not both"
assert str(cons) == "either a or b should be true"

# ensure nothing goes wrong due to calling __str__ on a constraint with a custom description
for solver,cls in cp.SolverLookup.base_solvers():
Expand All @@ -465,18 +465,58 @@ def test_description(self):

## test extra attributes of set_description
cons = a | b
cons.set_description("either a or b should be true, but not both",
cons.set_description("either a or b should be true",
override_print=False)

assert repr(cons) == "(a) or (b)"
assert str(cons) == "(a) or (b)"

cons = a | b
cons.set_description("either a or b should be true, but not both",
cons.set_description("either a or b should be true",
full_print=True)

assert repr(cons) == "(a) or (b)"
assert str(cons) == "either a or b should be true, but not both -- (a) or (b)"
assert str(cons) == "either a or b should be true -- (a) or (b)"

def test_nested_description(self):

a = cp.boolvar(name="a")
b = cp.boolvar(name="b")
c = cp.boolvar(name="c")
cons = a | b
cons.set_description("either a or b should be true")
assert repr(cons) == "(a) or (b)"
assert str(cons) == "either a or b should be true"

cons2 = c.implies(cons)

assert repr(cons2) == "(c) -> ((a) or (b))" # don't use description of nested expression
assert str(cons2) == "(c) -> (either a or b should be true)" # use description of nested expression


def test_expr_list_vs_ndarray(self):

x = cp.intvar(0,10,shape=3, name=("a","b","c"))
assert isinstance(x, NDVarArray)

cons = cp.AllDifferent(x)
assert repr(cons) == "alldifferent(a,b,c)"
assert str(cons) == "alldifferent(a,b,c)"

cons = cp.AllDifferent(list(x))
assert repr(cons) == "alldifferent(a,b,c)"
assert str(cons) == "alldifferent(a,b,c)"

cons = cp.Count(x, 3) >= 1
assert repr(cons) == "count([a, b, c],3) >= 1"
assert str(cons) == "count([a, b, c],3) >= 1"

cons = cp.Count(list(x), 3) >= 1
assert repr(cons) == "count([a, b, c],3) >= 1"
assert str(cons) == "count([a, b, c],3) >= 1"

assert str(cp.Count(x,2)) == str(cp.Count(list(x),2))
assert repr(cp.Count(x,2)) == repr(cp.Count(list(x),2)) # should be the same, functionally equivalent


def test_dtype(self):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_get_or_make_var__num(self):

v, cons = get_or_make_var(cp.cpm_array([1, 2, 3])[a])
assert str(v) == "IV13"
assert {str(c) for c in cons} == {"([1 2 3][IV0]) == (IV13)"}
assert {str(c) for c in cons} == {"([1, 2, 3][IV0]) == (IV13)"}

v, cons = get_or_make_var(cp.cpm_array([b + c, 2, 3])[a])
assert str(v) == "IV15"
Expand Down Expand Up @@ -269,8 +269,8 @@ def test_objective(self):
assert str(flatten_objective( 2*a-3*(b - c*2) )) == '(sum([2, -3, 6] * [IV0, IV1, IV2]), [])'
cp.intvar(0,2) # increase counter
assert str(flatten_objective( a//b+c )) == f"((IV6) + ({str(c)}), [(({str(a)}) div ({str(b)})) == (IV6)])"
assert str(flatten_objective( cp.cpm_array([1,2,3])[a] )) == "(IV7, [([1 2 3][IV0]) == (IV7)])"
assert str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )) == "((IV8) + (IV1), [([1 2 3][IV0]) == (IV8)])"
assert str(flatten_objective( cp.cpm_array([1,2,3])[a] )) == "(IV7, [([1, 2, 3][IV0]) == (IV7)])"
assert str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )) == "((IV8) + (IV1), [([1, 2, 3][IV0]) == (IV8)])"


def test_constraint(self):
Expand Down Expand Up @@ -304,10 +304,10 @@ def test_constraint(self):
#self.assertEqual( str(flatten_constraint( c != a + b )), "[((IV0) + (IV1)) != (IV2)]" ) # TODO, make it do the swap (again)
assert str(flatten_constraint( ((a > 5) == (b < 3)) )) == "[(IV0 > 5) == (BV8), (IV1 < 3) == (BV8)]"

assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] == b )) == "[([1 2 3][IV0]) == (IV1)]"
assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] > b )) == "[([1 2 3][IV0]) > (IV1)]"
assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] == b )) == "[([1, 2, 3][IV0]) == (IV1)]"
assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] > b )) == "[([1, 2, 3][IV0]) > (IV1)]"
cp.intvar(0,2, 4) # increase counter
assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] <= b )) == "[([1 2 3][IV0]) <= (IV1)]"
assert str(flatten_constraint( cp.cpm_array([1,2,3])[a] <= b )) == "[([1, 2, 3][IV0]) <= (IV1)]"
assert str(flatten_constraint( cp.AllDifferent([a+b,b+c,c+3]) )) == "[alldifferent(IV9,IV10,IV11), ((IV0) + (IV1)) == (IV9), ((IV1) + (IV2)) == (IV10), ((IV2) + 3) == (IV11)]"

# issue #27
Expand Down
Loading
Loading