Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,3 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = true
docstring-code-line-length = "dynamic"


1 change: 1 addition & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._arraykit import array_to_tuple_array as array_to_tuple_array
from ._arraykit import array_to_tuple_iter as array_to_tuple_iter
from ._arraykit import nonzero_1d as nonzero_1d
from ._arraykit import transition_slices_from_group as transition_slices_from_group
from ._arraykit import is_objectable_dt64 as is_objectable_dt64
from ._arraykit import is_objectable as is_objectable
from ._arraykit import astype_array as astype_array
Expand Down
3 changes: 3 additions & 0 deletions src/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def write_array_to_file(
def first_true_1d(__array: np.ndarray, *, forward: bool) -> int: ...
def first_true_2d(__array: np.ndarray, *, forward: bool, axis: int) -> np.ndarray: ...
def nonzero_1d(__array: np.ndarray, /) -> np.ndarray: ...
# Given a 1D or 2D array of group labels, return a 2-tuple:
#   [0] an iterator of slices, one per run of consecutive equal values (1D) or
#       equal rows (2D); the final slice has an open (None) stop.
#   [1] a bool that is True when the input array is 2D.
def transition_slices_from_group(
__group: np.ndarray, /
) -> tp.Tuple[tp.Iterator[slice], bool]: ...
def is_objectable_dt64(__array: np.ndarray, /) -> bool: ...
def is_objectable(__array: np.ndarray, /) -> bool: ...
def astype_array(__array: np.ndarray, __dtype: np.dtype | None, /) -> np.ndarray: ...
Expand Down
1 change: 1 addition & 0 deletions src/_arraykit.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ static PyMethodDef arraykit_methods[] = {
NULL},
{"count_iteration", count_iteration, METH_O, NULL},
{"nonzero_1d", nonzero_1d, METH_O, NULL},
{"transition_slices_from_group", transition_slices_from_group, METH_O, NULL},
{"is_objectable_dt64", is_objectable_dt64, METH_O, NULL},
{"is_objectable", is_objectable, METH_O, NULL},
{"astype_array", astype_array, METH_VARARGS, NULL},
Expand Down
326 changes: 326 additions & 0 deletions src/methods.c
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,332 @@ nonzero_1d(PyObject *Py_UNUSED(m), PyObject *a) {
return AK_nonzero_1d(array);
}

// Build `slice(start, stop, None)` and append it to the list `slices`.
// A `stop` of -1 is a sentinel meaning "open ended" and is stored as None.
// Returns 0 on success and -1 on failure (with a Python exception set by the
// failing C-API call). No reference to the new slice is retained here; the
// list owns it after a successful append.
static inline int
AK_append_transition_slice(PyObject* slices, npy_intp start, npy_intp stop)
{
    PyObject* lower = PyLong_FromSsize_t(start);
    if (lower == NULL) {
        return -1;
    }

    PyObject* upper;
    if (stop != -1) {
        upper = PyLong_FromSsize_t(stop);
        if (upper == NULL) {
            Py_DECREF(lower);
            return -1;
        }
    }
    else {
        // Open-ended run: the slice stop is None.
        upper = Py_None;
        Py_INCREF(upper);
    }

    // PySlice_New takes its own references; ours are released right away.
    PyObject* segment = PySlice_New(lower, upper, Py_None);
    Py_DECREF(lower);
    Py_DECREF(upper);
    if (segment == NULL) {
        return -1;
    }

    int status = PyList_Append(slices, segment);
    Py_DECREF(segment);
    return status;
}

// Copy the `n`-byte scalar stored at `src` into `dst` in native byte order.
// When `swapped` is non-zero the stored bytes are in the opposite byte order
// and are reversed during the copy. Byte-wise copying also makes the load safe
// for unaligned element addresses.
static inline void
AK_load_scalar_native(void* dst, const char* src, size_t n, int swapped)
{
    if (!swapped) {
        memcpy(dst, src, n);
        return;
    }
    unsigned char* out = (unsigned char*)dst;
    const unsigned char* in = (const unsigned char*)src;
    for (size_t k = 0; k < n; ++k) {
        out[k] = in[n - 1 - k];
    }
}

// Return 1 if the two elements are not equal (a transition), 0 if equal, -1 on
// error (Python exception set). Equality matches numpy's elementwise `!=`:
// byte-exact for integral, string, and void dtypes; IEEE-aware for floats
// (NaN != NaN, +0.0 == -0.0); and NaT-aware for datetime64/timedelta64 (NaT !=
// everything, including NaT). Float and datetime values are decoded to native
// byte order before comparison so byte-swapped arrays (e.g. '>f8', '>M8[ns]')
// get the same semantics as native ones. Any other dtype (complex, float16,
// longdouble, object, ...) falls back to a boxed Python comparison.
static inline int
AK_values_not_equal_1d(PyArrayObject* array, npy_intp left, npy_intp right)
{
    npy_intp stride = PyArray_STRIDE(array, 0);
    char* p_left = PyArray_BYTES(array) + (left * stride);
    char* p_right = PyArray_BYTES(array) + (right * stride);
    const int swapped = PyArray_ISBYTESWAPPED(array);

    switch (PyArray_TYPE(array)) {
        // For these dtypes equality is exactly byte equality, regardless of
        // byte order, so memcmp is correct and creates no Python objects.
        case NPY_BOOL:
        case NPY_BYTE: case NPY_UBYTE:
        case NPY_SHORT: case NPY_USHORT:
        case NPY_INT: case NPY_UINT:
        case NPY_LONG: case NPY_ULONG:
        case NPY_LONGLONG: case NPY_ULONGLONG:
        case NPY_STRING: case NPY_UNICODE: case NPY_VOID:
            return memcmp(p_left, p_right, (size_t)PyArray_ITEMSIZE(array)) != 0;

        // Floats need IEEE semantics. A byte compare would be wrong in both
        // directions: bit-identical NaNs must compare unequal, and +0.0 / -0.0
        // (different bits) must compare equal.
        case NPY_FLOAT: {
            float a, b;
            AK_load_scalar_native(&a, p_left, sizeof a, swapped);
            AK_load_scalar_native(&b, p_right, sizeof b, swapped);
            return a != b;
        }
        case NPY_DOUBLE: {
            double a, b;
            AK_load_scalar_native(&a, p_left, sizeof a, swapped);
            AK_load_scalar_native(&b, p_right, sizeof b, swapped);
            return a != b;
        }

        // datetime64/timedelta64 are int64 with NaT == NPY_DATETIME_NAT; the
        // sentinel is only recognizable after decoding to native byte order.
        case NPY_DATETIME:
        case NPY_TIMEDELTA: {
            npy_int64 a, b;
            AK_load_scalar_native(&a, p_left, sizeof a, swapped);
            AK_load_scalar_native(&b, p_right, sizeof b, swapped);
            if (a == NPY_DATETIME_NAT || b == NPY_DATETIME_NAT) {
                return 1;
            }
            return a != b;
        }

        default:
            break;
    }

    // Object dtype and any unhandled type: compare via Python objects.
    PyObject* left_value = PyArray_GETITEM(array, p_left);
    if (!left_value) {
        return -1;
    }
    PyObject* right_value = PyArray_GETITEM(array, p_right);
    if (!right_value) {
        Py_DECREF(left_value);
        return -1;
    }
    // PyObject_RichCompareBool returns 1/0/-1, matching this function's contract.
    int is_not_equal = PyObject_RichCompareBool(left_value, right_value, Py_NE);
    Py_DECREF(left_value);
    Py_DECREF(right_value);
    return is_not_equal;
}

// Compare rows `left` and `right` of a 2d array. Returns 1 if the rows are
// equal, 0 if they differ, and -1 on error (Python exception set).
//
// Object arrays are compared cell-by-cell with rich equality (Py_EQ),
// consistent with the 1d object path (AK_values_not_equal_1d). All other
// dtypes are compared bytewise, so for floats NaN rows with identical bits
// compare equal and +0.0 / -0.0 compare unequal.
static inline int
AK_rows_equal_2d(PyArrayObject* array, npy_intp left, npy_intp right)
{
    npy_intp cols = PyArray_DIM(array, 1);

    if (PyArray_TYPE(array) == NPY_OBJECT) {
        for (npy_intp col = 0; col < cols; ++col) {
            PyObject* left_value = *(PyObject**)PyArray_GETPTR2(array, left, col);
            PyObject* right_value = *(PyObject**)PyArray_GETPTR2(array, right, col);
            // An unset (NULL) object slot is treated as None.
            if (!left_value) {
                left_value = Py_None;
            }
            if (!right_value) {
                right_value = Py_None;
            }
            // The comparison may run arbitrary Python code (__eq__) that could
            // replace these array elements; hold strong references so both
            // operands stay alive for the duration of the call.
            Py_INCREF(left_value);
            Py_INCREF(right_value);
            int is_equal = PyObject_RichCompareBool(left_value, right_value, Py_EQ);
            Py_DECREF(left_value);
            Py_DECREF(right_value);
            if (is_equal <= 0) {
                return is_equal; // 0 (not equal) or -1 (error)
            }
        }
        return 1;
    }

    npy_intp stride_0 = PyArray_STRIDE(array, 0);
    npy_intp stride_1 = PyArray_STRIDE(array, 1);
    npy_intp itemsize = PyArray_ITEMSIZE(array);

    char* p_left = PyArray_BYTES(array) + (left * stride_0);
    char* p_right = PyArray_BYTES(array) + (right * stride_0);

    // Cells of a row are packed back-to-back: compare the whole row at once.
    if (stride_1 == itemsize) {
        return memcmp(p_left, p_right, (size_t)(cols * itemsize)) == 0;
    }

    // Otherwise step through the row one cell at a time using the column stride.
    for (npy_intp col = 0; col < cols; ++col) {
        npy_intp offset = col * stride_1;
        if (memcmp(p_left + offset, p_right + offset, (size_t)itemsize) != 0) {
            return 0;
        }
    }
    return 1;
}

// Scan a packed, typed 1d buffer for transitions. EQ is an expression that
// tests whether element i equals element i-1; the inner loop walks an equal
// run with no function call so the compiler can keep it tight. On a
// transition the run [start, i) is appended and scanning continues.
//
// The macro is only usable inside AK_fill_transition_slices_1d: it reads
// `size` and `slices`, updates `start`, and executes `return -1` from the
// enclosing function if appending fails. Callers `goto finalize` afterwards
// to emit the trailing open-ended run.
#define AK_SCAN_TRANSITIONS(EQ) \
    do { \
        npy_intp i = 1; \
        while (i < size) { \
            while (i < size && (EQ)) { ++i; } \
            if (i >= size) { break; } \
            if (AK_append_transition_slice(slices, start, i)) { return -1; } \
            start = i; \
            ++i; \
        } \
    } while (0)

// Append transition slices for the 1d array `working` (length `size`) to the
// list `slices`. Returns 0 on success, -1 on error (Python exception set).
//
// Arrays whose elements are packed (stride == itemsize) AND aligned for their
// dtype use a hoisted, typed scan. Alignment is required because the typed
// scans dereference the buffer through `double*`, `npy_int64*`, etc.; a packed
// array can still start at an unaligned address. The float and datetime scans
// additionally require native byte order, since they interpret the stored
// bytes as values (NaN, signed zero, NaT). Everything else (strided,
// unaligned, byte-swapped float/datetime, strings, void, complex, half,
// longdouble, unusual widths) goes through AK_values_not_equal_1d per element.
static int
AK_fill_transition_slices_1d(PyArrayObject* working, npy_intp size, PyObject* slices)
{
    npy_intp start = 0;
    const char* base = PyArray_BYTES(working);
    npy_intp stride = PyArray_STRIDE(working, 0);
    npy_intp itemsize = PyArray_ITEMSIZE(working);
    const int native_order = PyArray_ISNOTSWAPPED(working);

    if (stride == itemsize && PyArray_ISALIGNED(working)) {
        switch (PyArray_TYPE(working)) {
            case NPY_DOUBLE: {
                if (!native_order) {
                    break; // leave byte-swapped data to the generic path
                }
                // == honors IEEE semantics: NaN != NaN, +0.0 == -0.0
                const double* v = (const double*)base;
                AK_SCAN_TRANSITIONS(v[i] == v[i - 1]);
                goto finalize;
            }
            case NPY_FLOAT: {
                if (!native_order) {
                    break;
                }
                const float* v = (const float*)base;
                AK_SCAN_TRANSITIONS(v[i] == v[i - 1]);
                goto finalize;
            }
            case NPY_DATETIME:
            case NPY_TIMEDELTA: {
                if (!native_order) {
                    break;
                }
                // NaT (NPY_DATETIME_NAT) is unequal to everything, including itself
                const npy_int64* v = (const npy_int64*)base;
                AK_SCAN_TRANSITIONS(v[i - 1] != NPY_DATETIME_NAT
                        && v[i] != NPY_DATETIME_NAT
                        && v[i] == v[i - 1]);
                goto finalize;
            }
            case NPY_BOOL:
            case NPY_BYTE: case NPY_UBYTE:
            case NPY_SHORT: case NPY_USHORT:
            case NPY_INT: case NPY_UINT:
            case NPY_LONG: case NPY_ULONG:
            case NPY_LONGLONG: case NPY_ULONGLONG: {
                // Integral dtypes: `!=` is raw-bit inequality (independent of
                // byte order), so compare as unsigned words of the same width.
                switch (itemsize) {
                    case 1: { const npy_uint8* v = (const npy_uint8*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
                    case 2: { const npy_uint16* v = (const npy_uint16*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
                    case 4: { const npy_uint32* v = (const npy_uint32*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
                    case 8: { const npy_uint64* v = (const npy_uint64*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
                    default: break;
                }
                break; // unusual width: use the generic path
            }
            case NPY_OBJECT: {
                // Read the stored PyObject* directly instead of going through
                // PyArray_GETITEM. The compare can error, so this can't use
                // the macro.
                PyObject** v = (PyObject**)base;
                npy_intp i = 1;
                while (i < size) {
                    while (i < size) {
                        // An unset (NULL) object slot is treated as None.
                        PyObject* current = v[i] ? v[i] : Py_None;
                        PyObject* previous = v[i - 1] ? v[i - 1] : Py_None;
                        // __eq__ may run arbitrary Python code that replaces
                        // array elements; hold strong references so both
                        // operands stay alive across the call.
                        Py_INCREF(current);
                        Py_INCREF(previous);
                        int eq = PyObject_RichCompareBool(current, previous, Py_EQ);
                        Py_DECREF(current);
                        Py_DECREF(previous);
                        if (eq < 0) {
                            return -1;
                        }
                        if (!eq) {
                            break;
                        }
                        ++i;
                    }
                    if (i >= size) {
                        break;
                    }
                    if (AK_append_transition_slice(slices, start, i)) {
                        return -1;
                    }
                    start = i;
                    ++i;
                }
                goto finalize;
            }
            default:
                break;
        }
    }

    // Generic fallback: one comparator call per adjacent pair.
    for (npy_intp i = 1; i < size; ++i) {
        int is_transition = AK_values_not_equal_1d(working, i - 1, i);
        if (is_transition < 0) {
            return -1;
        }
        if (is_transition) {
            if (AK_append_transition_slice(slices, start, i)) {
                return -1;
            }
            start = i;
        }
    }

finalize:
    // Emit the trailing run with an open (None) stop; skipped for empty input.
    if (start < size) {
        if (AK_append_transition_slice(slices, start, -1)) {
            return -1;
        }
    }
    return 0;
}
#undef AK_SCAN_TRANSITIONS

// Module-level entry point. `a` must satisfy AK_CHECK_NUMPY_ARRAY_1D_2D.
// Returns a new 2-tuple `(slices_iter, is_2d)`:
//   - `slices_iter` iterates over slice objects, one per run of consecutive
//     equal values (1d input) or equal rows (2d input); the last slice has a
//     None stop.
//   - `is_2d` is True when the input has two dimensions.
// Returns NULL with a Python exception set on failure.
PyObject *
transition_slices_from_group(PyObject *Py_UNUSED(m), PyObject *a)
{
    AK_CHECK_NUMPY_ARRAY_1D_2D(a);
    PyArrayObject* working = (PyArrayObject*)a;
    const bool group_to_tuple = (PyArray_NDIM(working) == 2);
    const npy_intp size = PyArray_DIM(working, 0);

    PyObject* result = NULL;
    PyObject* slices_iter = NULL;

    // Keep the array alive for the duration of the scan.
    Py_INCREF(working);

    PyObject* slices = PyList_New(0);
    if (slices == NULL) {
        goto release_array;
    }

    if (group_to_tuple) {
        // 2d: a transition is any row that differs from the row before it.
        npy_intp run_start = 0;
        for (npy_intp row = 1; row < size; ++row) {
            int same = AK_rows_equal_2d(working, row - 1, row);
            if (same < 0) {
                goto release_list;
            }
            if (same) {
                continue;
            }
            if (AK_append_transition_slice(slices, run_start, row)) {
                goto release_list;
            }
            run_start = row;
        }
        // Trailing run, open ended; nothing is emitted for an empty array.
        if (run_start < size
                && AK_append_transition_slice(slices, run_start, -1)) {
            goto release_list;
        }
    }
    else if (AK_fill_transition_slices_1d(working, size, slices)) {
        goto release_list;
    }

    slices_iter = PyObject_GetIter(slices);
    if (slices_iter != NULL) {
        // PyTuple_Pack takes its own references to both members.
        result = PyTuple_Pack(2, slices_iter, group_to_tuple ? Py_True : Py_False);
        Py_DECREF(slices_iter);
    }

release_list:
    Py_DECREF(slices);
release_array:
    Py_DECREF(working);
    return result;
}

PyObject*
is_objectable_dt64(PyObject *m, PyObject *a) {
AK_CHECK_NUMPY_ARRAY(a);
Expand Down
3 changes: 3 additions & 0 deletions src/methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ resolve_dtype_iter(PyObject *Py_UNUSED(m), PyObject *arg);
PyObject *
nonzero_1d(PyObject *Py_UNUSED(m), PyObject *a);

// Given a 1D or 2D array `a`, return a new 2-tuple of (iterator of slices,
// bool). Each slice covers one run of consecutive equal values (1D) or equal
// rows (2D); the bool is True when `a` is 2D. Returns NULL on error.
PyObject *
transition_slices_from_group(PyObject *Py_UNUSED(m), PyObject *a);

PyObject *
is_objectable_dt64(PyObject *m, PyObject *a);

Expand Down
Loading
Loading