Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions code/ndarray_operators.c
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ mp_obj_t ndarray_binary_modulo(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
}
} else if(lhs->dtype == NDARRAY_UINT16) {
if(rhs->dtype == NDARRAY_UINT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
BINARY_LOOP(results, uint16_t, uint16_t, uint8_t, larray, lstrides, rarray, rstrides, %);
} else if(rhs->dtype == NDARRAY_INT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
Expand Down Expand Up @@ -1177,12 +1177,15 @@ mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rs
#if NDARRAY_HAS_INPLACE_MODULO
mp_obj_t ndarray_inplace_modulo(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rstrides) {
if((lhs->dtype != NDARRAY_FLOAT) && (rhs->dtype == NDARRAY_FLOAT)) {
mp_raise_TypeError(MP_ERROR_TEXT("results cannot be cast to specified type"));
mp_raise_TypeError(MP_ERROR_TEXT("cannot cast output with casting rule"));
}
uint8_t *larray = (uint8_t *)lhs->array;
uint8_t *rarray = (uint8_t *)rhs->array;
if(lhs->dtype == NDARRAY_FLOAT) {
// Float lhs: use fmod since C does not support %= for floating-point types.
if(rhs->dtype == NDARRAY_UINT8) {
INLINE_MODULO_FLOAT_LOOP(lhs, uint8_t, larray, rarray, rstrides);
} else if(rhs->dtype == NDARRAY_UINT8) {
} else if(rhs->dtype == NDARRAY_INT8) {
INLINE_MODULO_FLOAT_LOOP(lhs, int8_t, larray, rarray, rstrides);
} else if(rhs->dtype == NDARRAY_UINT16) {
INLINE_MODULO_FLOAT_LOOP(lhs, uint16_t, larray, rarray, rstrides);
Expand All @@ -1191,6 +1194,46 @@ mp_obj_t ndarray_inplace_modulo(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t
} else {
INLINE_MODULO_FLOAT_LOOP(lhs, mp_float_t, larray, rarray, rstrides);
}
} else if(lhs->dtype == NDARRAY_UINT8) {
if(rhs->dtype == NDARRAY_UINT8) {
INPLACE_LOOP(lhs, uint8_t, uint8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_INT8) {
INPLACE_LOOP(lhs, uint8_t, int8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_UINT16) {
INPLACE_LOOP(lhs, uint8_t, uint16_t, larray, rarray, rstrides, %=);
} else {
INPLACE_LOOP(lhs, uint8_t, int16_t, larray, rarray, rstrides, %=);
}
} else if(lhs->dtype == NDARRAY_INT8) {
if(rhs->dtype == NDARRAY_UINT8) {
INPLACE_LOOP(lhs, int8_t, uint8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_INT8) {
INPLACE_LOOP(lhs, int8_t, int8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_UINT16) {
INPLACE_LOOP(lhs, int8_t, uint16_t, larray, rarray, rstrides, %=);
} else {
INPLACE_LOOP(lhs, int8_t, int16_t, larray, rarray, rstrides, %=);
}
} else if(lhs->dtype == NDARRAY_UINT16) {
if(rhs->dtype == NDARRAY_UINT8) {
INPLACE_LOOP(lhs, uint16_t, uint8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_INT8) {
INPLACE_LOOP(lhs, uint16_t, int8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_UINT16) {
INPLACE_LOOP(lhs, uint16_t, uint16_t, larray, rarray, rstrides, %=);
} else {
INPLACE_LOOP(lhs, uint16_t, int16_t, larray, rarray, rstrides, %=);
}
} else if(lhs->dtype == NDARRAY_INT16) {
if(rhs->dtype == NDARRAY_UINT8) {
INPLACE_LOOP(lhs, int16_t, uint8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_INT8) {
INPLACE_LOOP(lhs, int16_t, int8_t, larray, rarray, rstrides, %=);
} else if(rhs->dtype == NDARRAY_UINT16) {
INPLACE_LOOP(lhs, int16_t, uint16_t, larray, rarray, rstrides, %=);
} else {
INPLACE_LOOP(lhs, int16_t, int16_t, larray, rarray, rstrides, %=);
}
}
return MP_OBJ_FROM_PTR(lhs);
}
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@
#endif

#ifndef NDARRAY_HAS_INPLACE_MODULO
#define NDARRAY_HAS_INPLACE_MODU (1)
#define NDARRAY_HAS_INPLACE_MODULO (1)
#endif

#ifndef NDARRAY_HAS_INPLACE_MULTIPLY
Expand Down
44 changes: 42 additions & 2 deletions tests/2d/numpy/modulo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,49 @@
print('=' * 30)
print()

# integer lhs %= float rhs raises TypeError (cannot store float result in integer array),
# matching NumPy's casting semantics for in-place operations.
for dtype1 in dtypes:
x1 = np.array(range(6), dtype=dtype1).reshape((2, 3))
for dtype2 in dtypes:
x2 = np.array(range(1, 4), dtype=dtype2)
x1 %= x2
print(x1)
try:
x1 %= x2
print(x1)
except TypeError:
print('TypeError')

print()
print('=' * 30)
print('regression tests')
print('=' * 30)
print()

# Bug: binary uint16 % uint8 was returning dtype=uint8 with corrupted values.
# Result dtype must be uint16 and values must be correct.
a = np.array([0, 1, 2, 3, 4, 5], dtype=np.uint16)
b = np.array([1, 2, 3, 1, 2, 3], dtype=np.uint8)
r = a % b
print(r)

# Bug: inplace_modulo was disabled by a typo in the feature-guard macro
# (NDARRAY_HAS_INPLACE_MODU instead of NDARRAY_HAS_INPLACE_MODULO), so %=
# silently fell back to binary % for all types. Verify it now works in-place
# for integer types without changing the array dtype.
a = np.array([3, 4, 5], dtype=np.uint8)
a %= np.array([2, 3, 4], dtype=np.uint8)
print(a)

a = np.array([3, 4, 5], dtype=np.int8)
a %= np.array([2, 3, 4], dtype=np.int8)
print(a)

a = np.array([3, 4, 5], dtype=np.uint16)
a %= np.array([2, 3, 4], dtype=np.uint8)
print(a)

# Bug: inplace float %= int8 was a no-op because the second branch checked
# NDARRAY_UINT8 again instead of NDARRAY_INT8. Verify fmod is applied.
a = np.array([3.5, 4.5, 5.5])
a %= np.array([2, 3, 4], dtype=np.int8)
print(a)
68 changes: 37 additions & 31 deletions tests/2d/numpy/modulo.py.exp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0, 0, 1],
[0, 2, 0]], dtype=uint8)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint16)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0, 1, 2],
Expand Down Expand Up @@ -56,37 +56,39 @@ inplace modulo
array([[0, 1, 2],
[0, 0, 2]], dtype=uint8)
array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
[0, 0, 2]], dtype=uint8)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint8)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint8)
TypeError
array([[0, 1, 2],
[0, 0, 2]], dtype=int8)
array([[0, 1, 2],
[0, 0, 2]], dtype=int8)
array([[0, 1, 2],
[0, 0, 2]], dtype=int8)
array([[0, 1, 2],
[0, 0, 2]], dtype=int8)
TypeError
array([[0, 1, 2],
[0, 0, 2]], dtype=uint16)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint16)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint16)
array([[0, 1, 2],
[0, 0, 2]], dtype=uint16)
TypeError
array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0, 0, 1],
[0, 2, 0]], dtype=uint8)
array([[0, 0, 1],
[0, 0, 0]], dtype=int16)
array([[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0]], dtype=float64)
array([[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0]], dtype=float64)
array([[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0]], dtype=float64)
array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
array([[0, 1, 2],
[0, 0, 2]], dtype=int16)
TypeError
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
Expand All @@ -97,9 +99,13 @@ array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)
array([[0.0, 1.0, 2.0],
[0.0, 0.0, 2.0]], dtype=float64)

==============================
regression tests
==============================

array([0, 1, 2, 0, 0, 2], dtype=uint16)
array([1, 1, 1], dtype=uint8)
array([1, 1, 1], dtype=int8)
array([1, 1, 1], dtype=uint16)
array([1.5, 1.5, 1.5], dtype=float64)
Loading