Skip to content

Commit 58574ad

Browse files
committed
Use PyNumber_Check for numeric type checks
1 parent 268abff commit 58574ad

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

src/pyscipopt/expr.pxi

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ from typing import TYPE_CHECKING, Union
4848
import numpy as np
4949

5050
from cpython.dict cimport PyDict_Next, PyDict_GetItem
51+
from cpython.number cimport PyNumber_Check
5152
from cpython.object cimport Py_LE, Py_EQ, Py_GE, Py_TYPE
5253
from cpython.ref cimport PyObject
5354
from cpython.tuple cimport PyTuple_GET_ITEM
@@ -63,15 +64,15 @@ def _expr_richcmp(self: Union[Expr, GenExpr], other, int op):
6364
return NotImplemented
6465

6566
if op == Py_LE:
66-
if isinstance(other, NUMBER_TYPES):
67+
if PyNumber_Check(other):
6768
return ExprCons(self, rhs=float(other))
6869
return (self - other) <= 0.0
6970
elif op == Py_GE:
70-
if isinstance(other, NUMBER_TYPES):
71+
if PyNumber_Check(other):
7172
return ExprCons(self, lhs=float(other))
7273
return (self - other) >= 0.0
7374
elif op == Py_EQ:
74-
if isinstance(other, NUMBER_TYPES):
75+
if PyNumber_Check(other):
7576
return ExprCons(self, lhs=float(other), rhs=float(other))
7677
return (self - other) == 0.0
7778
raise NotImplementedError("can only support with '<=', '>=', or '=='")
@@ -163,7 +164,7 @@ def buildGenExprObj(expr: Union[int, float, Expr, GenExpr]) -> GenExpr:
163164
if not isinstance(expr, GENEXPR_OP_TYPES):
164165
raise TypeError(f"Unsupported type {type(expr)}")
165166

166-
if isinstance(expr, NUMBER_TYPES):
167+
if PyNumber_Check(expr):
167168
return Constant(expr)
168169

169170
elif isinstance(expr, Expr):
@@ -223,7 +224,7 @@ cdef class Expr:
223224
# merge the terms by component-wise addition
224225
for v,c in right.terms.items():
225226
terms[v] = terms.get(v, 0.0) + c
226-
elif isinstance(right, NUMBER_TYPES):
227+
elif PyNumber_Check(right):
227228
c = float(right)
228229
terms[CONST] = terms.get(CONST, 0.0) + c
229230
return Expr(terms)
@@ -235,7 +236,7 @@ cdef class Expr:
235236
if isinstance(other, Expr):
236237
for v,c in other.terms.items():
237238
self.terms[v] = self.terms.get(v, 0.0) + c
238-
elif isinstance(other, NUMBER_TYPES):
239+
elif PyNumber_Check(other):
239240
c = float(other)
240241
self.terms[CONST] = self.terms.get(CONST, 0.0) + c
241242
return self
@@ -244,7 +245,7 @@ cdef class Expr:
244245
if not isinstance(other, EXPR_OP_TYPES):
245246
return NotImplemented
246247

247-
if isinstance(other, NUMBER_TYPES):
248+
if PyNumber_Check(other):
248249
f = float(other)
249250
return Expr({v: f * c for v, c in self.terms.items()})
250251

@@ -273,7 +274,7 @@ cdef class Expr:
273274
if not isinstance(other, EXPR_OP_TYPES):
274275
return NotImplemented
275276

276-
if isinstance(other, NUMBER_TYPES):
277+
if PyNumber_Check(other):
277278
return 1.0 / other * self
278279
return buildGenExprObj(self) / other
279280

@@ -299,7 +300,7 @@ cdef class Expr:
299300
Implements base**x as scip.exp(x * scip.log(base)).
300301
Note: base must be positive.
301302
"""
302-
if not isinstance(other, NUMBER_TYPES):
303+
if not PyNumber_Check(other):
303304
raise TypeError(f"Unsupported base type {type(other)} for exponentiation.")
304305
if other <= 0.0:
305306
raise ValueError("Base of a**x must be positive, as expression is reformulated to scip.exp(x * scip.log(a)); got %g" % other)
@@ -385,7 +386,7 @@ cdef class ExprCons:
385386

386387
def __richcmp__(self, other, op):
387388
'''turn it into a constraint'''
388-
if not isinstance(other, NUMBER_TYPES):
389+
if not PyNumber_Check(other):
389390
raise TypeError('Ranged ExprCons is not well defined!')
390391

391392
if op == 1: # <=
@@ -593,7 +594,7 @@ cdef class GenExpr:
593594
Implements base**x as scip.exp(x * scip.log(base)).
594595
Note: base must be positive.
595596
"""
596-
if not isinstance(other, NUMBER_TYPES):
597+
if not PyNumber_Check(other):
597598
raise TypeError(f"Unsupported base type {type(other)} for exponentiation.")
598599
if other <= 0.0:
599600
raise ValueError("Base of a**x must be positive, as expression is reformulated to scip.exp(x * scip.log(a)); got %g" % other)
@@ -864,6 +865,5 @@ def expr_to_array(expr, nodes):
864865
return len(nodes) - 1
865866

866867

867-
cdef tuple NUMBER_TYPES = (int, float, np.number)
868-
cdef tuple EXPR_OP_TYPES = NUMBER_TYPES + (Expr,)
868+
cdef tuple EXPR_OP_TYPES = (int, float, np.number, Expr)
869869
cdef tuple GENEXPR_OP_TYPES = EXPR_OP_TYPES + (GenExpr,)

0 commit comments

Comments
 (0)