Skip to content

Commit d317543

Browse files
committed
update permutation and multinomial implementations
1 parent 5574ceb commit d317543

File tree

1 file changed

+60
-18
lines changed

1 file changed

+60
-18
lines changed

mkl_random/mklrand.pyx

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6303,29 +6303,55 @@ cdef class _MKLRandomState:
63036303
array([100, 0])
63046304
63056305
"""
6306-
cdef cnp.npy_intp d
6306+
cdef cnp.npy_intp d, sz, niter
63076307
cdef cnp.ndarray parr "arrayObject_parr", mnarr "arrayObject_mnarr"
63086308
cdef double *pix
63096309
cdef int *mnix
6310-
cdef cnp.npy_intp sz
6311-
6312-
d = len(pvals)
6313-
parr = <cnp.ndarray>cnp.PyArray_ContiguousFromObject(
6314-
pvals, cnp.NPY_DOUBLE, 1, 1
6310+
cdef long ni
6311+
6312+
parr = <cnp.ndarray>cnp.PyArray_FROMANY(
6313+
pvals,
6314+
cnp.NPY_DOUBLE,
6315+
0,
6316+
1,
6317+
cnp.NPY_ARRAY_ALIGNED | cnp.NPY_ARRAY_C_CONTIGUOUS
63156318
)
6319+
if cnp.PyArray_NDIM(parr) == 0:
6320+
raise TypeError("pvals must be a 1-d sequence")
6321+
d = cnp.PyArray_SIZE(parr)
63166322
pix = <double*>cnp.PyArray_DATA(parr)
6317-
6318-
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
6319-
raise ValueError("sum(pvals[:-1]) > 1.0")
6320-
6323+
if (
6324+
not np.all(np.greater_equal(parr, 0))
6325+
or not np.all(np.less_equal(parr, 1))
6326+
):
6327+
raise ValueError("pvals < 0, pvals > 1 or pvals is NaN")
6328+
6329+
if d and kahan_sum(pix, d - 1) > (1.0 + 1e-12):
6330+
# When pvals is a floating-point array whose dtype is not float64 and its sum is only slightly above 1, give a more informative error
6331+
# 1.0001 works for float16 and float32
6332+
if (isinstance(pvals, np.ndarray)
6333+
and np.issubdtype(pvals.dtype, np.floating)
6334+
and pvals.dtype != float
6335+
and pvals.sum() < 1.0001):
6336+
msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals "
6337+
"array is cast to 64-bit floating point prior to "
6338+
"checking the sum. Precision changes when casting may "
6339+
"cause problems even if the sum of the original pvals "
6340+
"is valid.")
6341+
else:
6342+
msg = "sum(pvals[:-1]) > 1.0"
6343+
raise ValueError(msg)
63216344
shape = _shape_from_size(size, d)
63226345
multin = np.zeros(shape, np.int32)
6323-
63246346
mnarr = <cnp.ndarray>multin
63256347
mnix = <int*>cnp.PyArray_DATA(mnarr)
63266348
sz = cnp.PyArray_SIZE(mnarr)
6327-
6328-
irk_multinomial_vec(self.internal_state, sz // d, mnix, n, d, pix)
6349+
ni = n
6350+
if (ni < 0):
6351+
raise ValueError("n < 0")
6352+
# numpy#20483: Avoids divide by 0
6353+
niter = sz // d if d else 0
6354+
irk_multinomial_vec(self.internal_state, niter, mnix, n, d, pix)
63296355

63306356
return multin
63316357

@@ -6614,11 +6640,27 @@ cdef class _MKLRandomState:
66146640
66156641
"""
66166642
if isinstance(x, (int, np.integer)):
6617-
arr = np.arange(x)
6618-
else:
6619-
arr = np.array(x)
6620-
self.shuffle(arr)
6621-
return arr
6643+
# keep using long as the default here (main numpy switched to intp)
6644+
arr = np.arange(x, dtype=np.result_type(x, np.long))
6645+
self.shuffle(arr)
6646+
return arr
6647+
6648+
arr = np.asarray(x)
6649+
if arr.ndim < 1:
6650+
raise IndexError("x must be an integer or at least 1-dimensional")
6651+
6652+
# shuffle has fast-path for 1-d
6653+
if arr.ndim == 1:
6654+
# Return a copy if same memory
6655+
if np.may_share_memory(arr, x):
6656+
arr = np.array(arr)
6657+
self.shuffle(arr)
6658+
return arr
6659+
6660+
# Shuffle index array, dtype to ensure fast path
6661+
idx = np.arange(arr.shape[0], dtype=np.intp)
6662+
self.shuffle(idx)
6663+
return arr[idx]
66226664

66236665

66246666
cdef class MKLRandomState(_MKLRandomState):

0 commit comments

Comments
 (0)