diff --git a/CHANGELOG.md b/CHANGELOG.md index 80bea8d07..1d81d54bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased ### Added - Added `addConsCumulative()` for SCIP cumulative constraints (#1222) +- `Expr` and `GenExpr` support `__pos__` magic method like `+Expr` or `+GenExpr` ### Fixed ### Changed - Move magic methods (`__radd__`, `__sub__`, `__rsub__`, `__rmul__`, `__richcmp__`, `__neg__`, and `__rtruediv__`) to `ExprLike` base class (#1204) diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 62f0c880d..3b232fea2 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -278,6 +278,9 @@ cdef class ExprLike: def __neg__(self, /) -> Union[Expr, GenExpr]: return self * -1.0 + def __pos__(self, /) -> Union[Expr, GenExpr]: + return self.copy() + def __abs__(self) -> GenExpr: return UnaryExpr(Operator.fabs, buildGenExprObj(self)) @@ -296,6 +299,11 @@ cdef class ExprLike: def cos(self) -> GenExpr: return UnaryExpr(Operator.cos, buildGenExprObj(self)) + cdef ExprLike copy(self, bint copy=True): + raise NotImplementedError( + f"{self.__class__.__name__!s} need to implement copy() method" + ) + ##@details Polynomial expressions of variables with operator overloading. \n #See also the @ref ExprDetails "description" in the expr.pxi. @@ -435,6 +443,12 @@ cdef class Expr(ExprLike): res += coef * term._evaluate(sol) return res + cdef Expr copy(self, bint copy=True): + cdef object cls = Py_TYPE(self) + cdef Expr res = cls.__new__(cls) + res.terms = self.terms.copy() if copy else self.terms + return res + cdef class ExprCons: '''Constraints with a polynomial expressions and lower/upper bounds.''' @@ -703,18 +717,11 @@ cdef class GenExpr(ExprLike): '''returns operator of GenExpr''' return self._op - cdef GenExpr copy(self, bool copy = True): + cdef GenExpr copy(self, bint copy=True): cdef object cls = Py_TYPE(self) cdef GenExpr res = cls.__new__(cls) res._op = self._op res.children = self.children.copy() if copy else self.children - if cls is SumExpr: - (res).constant = (self).constant - (res).coefs = (self).coefs.copy() if copy else (self).coefs - if cls is ProdExpr: - (res).constant = (self).constant - elif cls is PowExpr: - (res).expo = (self).expo return res @@ -741,6 +748,14 @@ cdef class SumExpr(GenExpr): res += coefs[i] * (children[i])._evaluate(sol) return res + cdef SumExpr copy(self, bint copy=True): + cdef SumExpr res = SumExpr.__new__(SumExpr) + res._op = self._op + res.children = self.children.copy() if copy else self.children + res.constant = self.constant + res.coefs = self.coefs.copy() if copy else self.coefs + return res + # Prod Expressions cdef class ProdExpr(GenExpr): @@ -765,6 +780,13 @@ cdef class ProdExpr(GenExpr): return 0.0 return res + cdef ProdExpr copy(self, bint copy=True): + cdef ProdExpr res = ProdExpr.__new__(ProdExpr) + res._op = self._op + res.children = self.children.copy() if copy else self.children + res.constant = self.constant + return res + # Var Expressions cdef class VarExpr(GenExpr): @@ -798,6 +820,13 @@ cdef class PowExpr(GenExpr): cpdef double _evaluate(self, Solution sol) except *: return (self.children[0])._evaluate(sol) ** self.expo + cdef PowExpr copy(self, bint copy=True): + cdef PowExpr res = PowExpr.__new__(PowExpr) + res._op = self._op + res.children = self.children.copy() if copy else self.children + res.expo = self.expo + return res + # Exp, Log, Sqrt, Sin, Cos Expressions cdef class UnaryExpr(GenExpr): @@ -832,6 +861,13 @@ cdef class Constant(GenExpr): cpdef double _evaluate(self, Solution sol) except *: return self.number + cdef Constant copy(self, bint copy=True): + # The copy parameter doesn't work; this is for compatibility. + cdef Constant res = Constant.__new__(Constant) + res._op = self._op + res.number = self.number + return res + def exp(x): """ diff --git a/src/pyscipopt/scip.pxd b/src/pyscipopt/scip.pxd index 9ff2979e7..6db5be281 100644 --- a/src/pyscipopt/scip.pxd +++ b/src/pyscipopt/scip.pxd @@ -2151,7 +2151,8 @@ cdef extern from "tpi/tpi.h": int SCIPtpiGetNumThreads() cdef class ExprLike: - pass + + cdef ExprLike copy(self, bint copy=*) cdef class Expr(ExprLike): cdef public terms diff --git a/src/pyscipopt/scip.pxi b/src/pyscipopt/scip.pxi index 45f0ccf9e..8259a8ab0 100644 --- a/src/pyscipopt/scip.pxi +++ b/src/pyscipopt/scip.pxi @@ -1563,6 +1563,9 @@ cdef class Variable(Expr): cname = bytes( SCIPvarGetName(self.scip_var) ) return cname.decode('utf-8') + cdef Variable copy(self, bint copy=True): + return self + def ptr(self): return (self.scip_var) diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 86196cfc1..aa1cff2a5 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -324,6 +324,7 @@ class Eventhdlr: def eventinit(self) -> Incomplete: ... def eventinitsol(self) -> Incomplete: ... +@disjoint_base class ExprLike: def __array_ufunc__( self, @@ -338,6 +339,7 @@ class ExprLike: def __rmul__(self, other: object, /) -> Incomplete: ... def __rtruediv__(self, other: object, /) -> GenExpr: ... def __neg__(self, /) -> Union[Expr, GenExpr]: ... + def __pos__(self, /) -> Union[Expr, GenExpr]: ... def __abs__(self) -> GenExpr: ... def exp(self) -> GenExpr: ... def log(self) -> GenExpr: ... diff --git a/tests/test_expr.py b/tests/test_expr.py index f35096f73..1b51e4f2a 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -547,3 +547,50 @@ def test_Expr_iadd_Expr(): e1 += e2 assert str(e1) == "Expr({Term(x): -1.0, Term(): 0.0, Term(y): 1.0})" assert str(e2) == "Expr({Term(y): 1.0, Term(): -1.0})" + + +def test_pos(): + m = Model() + x = m.addVar(name="x") + + # test Variable + res = +x + assert str(res) == "x" + assert res is x + + # test Expr + e = x + 1 + res = +(x + 1) + assert str(res) == "Expr({Term(x): 1.0, Term(): 1.0})" + assert e is not res + + # test SumExpr + e = sqrt(x) + 1 + res = +e + assert str(res) == str(e) + assert e is not res + + # test UnaryExpr + e = cos(x) + res = +e + assert str(res) == str(e) + assert e is not res + + # test ProdExpr + e = x * sin(x) + res = +e + assert str(res) == str(e) + assert e is not res + + # test PowExpr + e = log(x)**2 + res = +e + assert str(res) == str(e) + assert e is not res + + # test Constant + c = sqrt(1).children[0] + assert type(c) is not int + e = +c + assert str(e) == str(c) + assert e is not c