Skip to content

Commit ed6edfc

Browse files
authored
Merge pull request #211 from static-frame/copilot/transition-slices-from-group
Implement `transition_slices_from_group` in C extension and expose public API
2 parents af0992e + 82e233a commit ed6edfc

7 files changed

Lines changed: 570 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,3 @@ skip-magic-trailing-comma = false
7474
line-ending = "auto"
7575
docstring-code-format = true
7676
docstring-code-line-length = "dynamic"
77-
78-

src/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ._arraykit import array_to_tuple_array as array_to_tuple_array
3434
from ._arraykit import array_to_tuple_iter as array_to_tuple_iter
3535
from ._arraykit import nonzero_1d as nonzero_1d
36+
from ._arraykit import transition_slices_from_group as transition_slices_from_group
3637
from ._arraykit import is_objectable_dt64 as is_objectable_dt64
3738
from ._arraykit import is_objectable as is_objectable
3839
from ._arraykit import astype_array as astype_array

src/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ def write_array_to_file(
227227
def first_true_1d(__array: np.ndarray, *, forward: bool) -> int: ...
228228
def first_true_2d(__array: np.ndarray, *, forward: bool, axis: int) -> np.ndarray: ...
229229
def nonzero_1d(__array: np.ndarray, /) -> np.ndarray: ...
230+
def transition_slices_from_group(
231+
__group: np.ndarray, /
232+
) -> tp.Tuple[tp.Iterator[slice], bool]: ...
230233
def is_objectable_dt64(__array: np.ndarray, /) -> bool: ...
231234
def is_objectable(__array: np.ndarray, /) -> bool: ...
232235
def astype_array(__array: np.ndarray, __dtype: np.dtype | None, /) -> np.ndarray: ...

src/_arraykit.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ static PyMethodDef arraykit_methods[] = {
5353
NULL},
5454
{"count_iteration", count_iteration, METH_O, NULL},
5555
{"nonzero_1d", nonzero_1d, METH_O, NULL},
56+
{"transition_slices_from_group", transition_slices_from_group, METH_O, NULL},
5657
{"is_objectable_dt64", is_objectable_dt64, METH_O, NULL},
5758
{"is_objectable", is_objectable, METH_O, NULL},
5859
{"astype_array", astype_array, METH_VARARGS, NULL},

src/methods.c

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,344 @@ nonzero_1d(PyObject *Py_UNUSED(m), PyObject *a) {
253253
return AK_nonzero_1d(array);
254254
}
255255

256+
//------------------------------------------------------------------------------
257+
// transition_slices_from_group
258+
259+
// Returns -1 on error
260+
static inline int
261+
AK_append_transition_slice(PyObject* slices, npy_intp start, npy_intp stop)
262+
{
263+
PyObject* py_start = PyLong_FromSsize_t(start);
264+
if (!py_start) {
265+
return -1;
266+
}
267+
268+
PyObject* py_stop = NULL;
269+
if (stop == -1) {
270+
py_stop = Py_None;
271+
Py_INCREF(py_stop);
272+
}
273+
else {
274+
py_stop = PyLong_FromSsize_t(stop);
275+
}
276+
if (!py_stop) {
277+
Py_DECREF(py_start);
278+
return -1;
279+
}
280+
281+
PyObject* slc = PySlice_New(py_start, py_stop, Py_None);
282+
Py_DECREF(py_start);
283+
Py_DECREF(py_stop);
284+
if (!slc) {
285+
return -1;
286+
}
287+
int append_result = PyList_Append(slices, slc); // -1 on error
288+
Py_DECREF(slc);
289+
return append_result;
290+
}
291+
292+
// Return 1 if the two elements are not equal (a transition), 0 if equal, -1 on
293+
// error. Equality matches numpy's elementwise `!=`: byte-exact for integral,
294+
// string, and void dtypes; IEEE-aware for floats (NaN != NaN, +0.0 == -0.0);
295+
// and NaT-aware for datetime64/timedelta64 (NaT != everything, including NaT).
296+
// Uncommon dtypes (complex, float16, longdouble, object) fall back to a boxed
297+
// Python comparison, which is correct though slower.
298+
static inline int
299+
AK_values_not_equal_1d(PyArrayObject* array, npy_intp left, npy_intp right)
300+
{
301+
npy_intp stride = PyArray_STRIDE(array, 0);
302+
char* p_left = PyArray_BYTES(array) + (left * stride);
303+
char* p_right = PyArray_BYTES(array) + (right * stride);
304+
305+
switch (PyArray_TYPE(array)) {
306+
// For these dtypes numpy `!=` is exactly a byte comparison, so memcmp
307+
// is both correct and the fastest option (no Python objects created).
308+
case NPY_BOOL:
309+
case NPY_BYTE: case NPY_UBYTE:
310+
case NPY_SHORT: case NPY_USHORT:
311+
case NPY_INT: case NPY_UINT:
312+
case NPY_LONG: case NPY_ULONG:
313+
case NPY_LONGLONG: case NPY_ULONGLONG:
314+
case NPY_STRING: case NPY_UNICODE: case NPY_VOID:
315+
return memcmp(p_left, p_right, (size_t)PyArray_ITEMSIZE(array)) != 0;
316+
317+
// Floats need IEEE semantics, which a byte compare would get wrong for
318+
// NaN (distinct bits compare equal under ==) and signed zero.
319+
case NPY_FLOAT: {
320+
npy_float a, b;
321+
memcpy(&a, p_left, sizeof a);
322+
memcpy(&b, p_right, sizeof b);
323+
return a != b;
324+
}
325+
case NPY_DOUBLE: {
326+
npy_double a, b;
327+
memcpy(&a, p_left, sizeof a);
328+
memcpy(&b, p_right, sizeof b);
329+
return a != b;
330+
}
331+
332+
// datetime64/timedelta64 are int64 with NaT == INT64_MIN;
333+
case NPY_DATETIME:
334+
case NPY_TIMEDELTA: {
335+
npy_int64 a, b;
336+
memcpy(&a, p_left, sizeof a);
337+
memcpy(&b, p_right, sizeof b);
338+
if (a == NPY_DATETIME_NAT || b == NPY_DATETIME_NAT) {
339+
return 1;
340+
}
341+
return a != b;
342+
}
343+
}
344+
345+
// Object dtype and any unhandled type: compare via Python objects.
346+
PyObject* left_value = PyArray_GETITEM(array, p_left);
347+
if (!left_value) {
348+
return -1;
349+
}
350+
PyObject* right_value = PyArray_GETITEM(array, p_right);
351+
if (!right_value) {
352+
Py_DECREF(left_value);
353+
return -1;
354+
}
355+
int is_not_equal = PyObject_RichCompareBool(left_value, right_value, Py_NE);
356+
Py_DECREF(left_value);
357+
Py_DECREF(right_value);
358+
return is_not_equal;
359+
}
360+
361+
// Compare two rows of a 2d array. Object arrays are compared cell-by-cell with
362+
// rich equality, consistent with the 1d object path (AK_values_not_equal_1d);
363+
// all other dtypes are compared bytewise, matching numpy's void-view equality.
364+
static inline int
365+
AK_rows_equal_2d(PyArrayObject* array, npy_intp left, npy_intp right)
366+
{
367+
if (PyArray_TYPE(array) == NPY_OBJECT) {
368+
npy_intp cols = PyArray_DIM(array, 1);
369+
for (npy_intp col = 0; col < cols; ++col) {
370+
PyObject* left_value = *(PyObject**)PyArray_GETPTR2(array, left, col);
371+
PyObject* right_value = *(PyObject**)PyArray_GETPTR2(array, right, col);
372+
int is_equal = PyObject_RichCompareBool(left_value, right_value, Py_EQ);
373+
if (is_equal <= 0) {
374+
return is_equal; // 0 (not equal) or -1 (error)
375+
}
376+
}
377+
return 1;
378+
}
379+
380+
npy_intp stride_0 = PyArray_STRIDE(array, 0);
381+
npy_intp stride_1 = PyArray_STRIDE(array, 1);
382+
npy_intp cols = PyArray_DIM(array, 1);
383+
npy_intp itemsize = PyArray_ITEMSIZE(array);
384+
385+
char* p_left = PyArray_BYTES(array) + (left * stride_0);
386+
char* p_right = PyArray_BYTES(array) + (right * stride_0);
387+
388+
if (stride_1 == itemsize) {
389+
return memcmp(p_left, p_right, (size_t)(cols * itemsize)) == 0;
390+
}
391+
392+
for (npy_intp col = 0; col < cols; ++col) {
393+
npy_intp offset = col * stride_1;
394+
if (memcmp(p_left + offset, p_right + offset, (size_t)itemsize) != 0) {
395+
return 0;
396+
}
397+
}
398+
return 1;
399+
}
400+
401+
// Scan a contiguous, monomorphic 1d buffer for transitions. EQ tests whether
402+
// element i equals element i-1; the inner loop runs over equal runs with no
403+
// function call so the compiler can keep it tight (and vectorize the compare).
404+
// On a transition, emit the run [start, i) and continue. `goto finalize` after.
405+
#define AK_SCAN_TRANSITIONS(EQ) \
406+
do { \
407+
npy_intp i = 1; \
408+
while (i < size) { \
409+
while (i < size && (EQ)) { ++i; } \
410+
if (i >= size) { break; } \
411+
if (AK_append_transition_slice(slices, start, i)) { return -1; } \
412+
start = i; \
413+
++i; \
414+
} \
415+
} while (0)
416+
417+
// Append transition slices for a 1d array to `slices`. Returns 0 on success,
418+
// -1 on error. Common contiguous dtypes use a hoisted, typed scan; everything
419+
// else (non-contiguous, strings, void, complex, half, longdouble, object, or
420+
// unusual widths) falls back to the per-element comparison.
421+
static int
422+
AK_fill_transition_slices_1d(PyArrayObject* working, npy_intp size, PyObject* slices)
423+
{
424+
npy_intp start = 0;
425+
const char* base = PyArray_BYTES(working);
426+
npy_intp stride = PyArray_STRIDE(working, 0);
427+
npy_intp itemsize = PyArray_ITEMSIZE(working);
428+
429+
// stride == itemsize proves the data is packed, but the typed fast-path
430+
// below casts `base` to a concrete C type and dereferences v[i] directly,
431+
// which is undefined behavior on a misaligned buffer (and can SIGBUS on
432+
// strict-alignment platforms). Contiguity and alignment are independent in
433+
// numpy, so require PyArray_ISALIGNED explicitly; anything unaligned falls
434+
// through to the memcpy/memcmp-based generic scan, which is alignment-safe.
435+
if (stride == itemsize && PyArray_ISALIGNED(working)) {
436+
switch (PyArray_TYPE(working)) {
437+
case NPY_DOUBLE: {
438+
// == honors IEEE semantics: NaN != NaN, +0.0 == -0.0
439+
const npy_double* v = (const npy_double*)base;
440+
AK_SCAN_TRANSITIONS(v[i] == v[i - 1]);
441+
goto finalize;
442+
}
443+
case NPY_FLOAT: {
444+
const npy_float* v = (const npy_float*)base;
445+
AK_SCAN_TRANSITIONS(v[i] == v[i - 1]);
446+
goto finalize;
447+
}
448+
case NPY_DATETIME:
449+
case NPY_TIMEDELTA: {
450+
// NaT (INT64_MIN) is unequal to everything, including itself
451+
const npy_int64* v = (const npy_int64*)base;
452+
AK_SCAN_TRANSITIONS(v[i - 1] != NPY_DATETIME_NAT
453+
&& v[i] != NPY_DATETIME_NAT
454+
&& v[i] == v[i - 1]);
455+
goto finalize;
456+
}
457+
case NPY_BOOL:
458+
case NPY_BYTE: case NPY_UBYTE:
459+
case NPY_SHORT: case NPY_USHORT:
460+
case NPY_INT: case NPY_UINT:
461+
case NPY_LONG: case NPY_ULONG:
462+
case NPY_LONGLONG: case NPY_ULONGLONG: {
463+
// Integral dtypes: `!=` is raw-bit inequality, so compare by width.
464+
switch (itemsize) {
465+
case 1: { const npy_uint8* v = (const npy_uint8*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
466+
case 2: { const npy_uint16* v = (const npy_uint16*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
467+
case 4: { const npy_uint32* v = (const npy_uint32*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
468+
case 8: { const npy_uint64* v = (const npy_uint64*)base; AK_SCAN_TRANSITIONS(v[i] == v[i - 1]); goto finalize; }
469+
default: break;
470+
}
471+
break;
472+
}
473+
case NPY_OBJECT: {
474+
// Read the stored PyObject* directly (borrowed, GIL held) and
475+
// rich-compare, skipping PyArray_GETITEM dispatch and refcount
476+
// churn. The compare can error, so this can't use the macro.
477+
PyObject** v = (PyObject**)base;
478+
npy_intp i = 1;
479+
while (i < size) {
480+
while (i < size) {
481+
int eq = PyObject_RichCompareBool(v[i], v[i - 1], Py_EQ);
482+
if (eq < 0) {
483+
return -1;
484+
}
485+
if (!eq) {
486+
break;
487+
}
488+
++i;
489+
}
490+
if (i >= size) {
491+
break;
492+
}
493+
if (AK_append_transition_slice(slices, start, i)) {
494+
return -1;
495+
}
496+
start = i;
497+
++i;
498+
}
499+
goto finalize;
500+
}
501+
default:
502+
break;
503+
}
504+
}
505+
506+
// Generic fallback.
507+
for (npy_intp i = 1; i < size; ++i) {
508+
int is_transition = AK_values_not_equal_1d(working, i - 1, i);
509+
if (is_transition < 0) {
510+
return -1;
511+
}
512+
if (is_transition) {
513+
if (AK_append_transition_slice(slices, start, i)) {
514+
return -1;
515+
}
516+
start = i;
517+
}
518+
}
519+
520+
finalize:
521+
if (start < size) {
522+
if (AK_append_transition_slice(slices, start, -1)) {
523+
return -1;
524+
}
525+
}
526+
return 0;
527+
}
528+
#undef AK_SCAN_TRANSITIONS
529+
530+
// Given a 1d or 2d array, return a 2-tuple `(slices_iter, group_to_tuple)`:
// `slices_iter` iterates over slice objects, one per run of consecutive equal
// elements (1d) or equal rows (2d) along axis 0, with the final slice open
// ended; `group_to_tuple` is True when the input is 2d, False otherwise.
// Returns NULL on error.
PyObject *
transition_slices_from_group(PyObject *Py_UNUSED(m), PyObject *a)
{
    // NOTE(review): presumably rejects anything that is not a 1d/2d numpy
    // array before the cast below -- confirm against the macro definition.
    AK_CHECK_NUMPY_ARRAY_1D_2D(a);
    PyArrayObject* group = (PyArrayObject*)a;
    const bool is_2d = (PyArray_NDIM(group) == 2);
    const npy_intp rows = PyArray_DIM(group, 0);

    PyObject* result = NULL;
    PyObject* slices = NULL;
    PyObject* slices_iter = NULL;

    // Hold our own reference to the input for the duration of the scan.
    Py_INCREF(group);

    slices = PyList_New(0);
    if (slices == NULL) {
        goto done;
    }

    if (is_2d) {
        npy_intp run_start = 0;
        for (npy_intp row = 1; row < rows; ++row) {
            int same = AK_rows_equal_2d(group, row - 1, row);
            if (same < 0) {
                goto done;
            }
            if (same) {
                continue;
            }
            // Row differs from its predecessor: close the run [run_start, row).
            if (AK_append_transition_slice(slices, run_start, row)) {
                goto done;
            }
            run_start = row;
        }
        // Trailing run is open ended; nothing is emitted for zero rows.
        if (run_start < rows) {
            if (AK_append_transition_slice(slices, run_start, -1)) {
                goto done;
            }
        }
    }
    else {
        if (AK_fill_transition_slices_1d(group, rows, slices)) {
            goto done;
        }
    }

    slices_iter = PyObject_GetIter(slices);
    if (slices_iter == NULL) {
        goto done;
    }
    // PyTuple_Pack takes its own references to both items.
    result = PyTuple_Pack(2, slices_iter, is_2d ? Py_True : Py_False);

done:
    Py_XDECREF(slices_iter);
    Py_XDECREF(slices);
    Py_DECREF(group);
    return result;
}
591+
592+
593+
//------------------------------------------------------------------------------
256594
PyObject*
257595
is_objectable_dt64(PyObject *m, PyObject *a) {
258596
AK_CHECK_NUMPY_ARRAY(a);

src/methods.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ resolve_dtype_iter(PyObject *Py_UNUSED(m), PyObject *arg);
5151
PyObject *
5252
nonzero_1d(PyObject *Py_UNUSED(m), PyObject *a);
5353

54+
PyObject *
55+
transition_slices_from_group(PyObject *Py_UNUSED(m), PyObject *a);
56+
5457
PyObject *
5558
is_objectable_dt64(PyObject *m, PyObject *a);
5659

0 commit comments

Comments
 (0)