Skip to content

Commit 88eef29

Browse files
committed
Switch to new numba string type
1 parent 9a705ff commit 88eef29

5 files changed

Lines changed: 73 additions & 108 deletions

File tree

python/_tskitmodule.c

Lines changed: 44 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
*/
2525

2626
#define PY_SSIZE_T_CLEAN
27-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
27+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
28+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
2829
#define TSK_BUG_ASSERT_MESSAGE \
2930
"Please open an issue on" \
3031
" GitHub, ideally with a reproducible example." \
@@ -10779,110 +10780,56 @@ TreeSequence_decode_ragged_string_column(
1077910780
TreeSequence *self, tsk_size_t num_rows, const char *data, const tsk_size_t *offset)
1078010781
{
1078110782
PyObject *ret = NULL;
10782-
PyArrayObject *ret_array = NULL;
10783+
PyObject *array = NULL;
10784+
PyArray_StringDTypeObject *string_dtype = NULL;
10785+
npy_string_allocator *allocator = NULL;
10786+
char *array_data = NULL;
1078310787
npy_intp dims[1];
10784-
char *string_data = NULL;
1078510788
tsk_size_t i;
10786-
tsk_size_t max_length;
10787-
tsk_size_t len;
10788-
tsk_size_t start;
10789-
tsk_size_t end;
10790-
char *dest;
10791-
bool all_length_one;
10789+
int pack_result;
1079210790

10791+
string_dtype = (PyArray_StringDTypeObject *) PyArray_DescrFromType(NPY_VSTRING);
10792+
if (string_dtype == NULL) {
10793+
goto out;
10794+
}
1079310795
dims[0] = (npy_intp) num_rows;
10794-
10795-
/* If all strings are length one we can back the array with tskit memory */
10796-
all_length_one = true;
10797-
for (i = 0; i < num_rows; i++) {
10798-
len = offset[i + 1] - offset[i];
10799-
if (len != 1) {
10800-
all_length_one = false;
10801-
break;
10802-
}
10796+
array = PyArray_Zeros(1, dims, (PyArray_Descr *) string_dtype, 0);
10797+
if (array == NULL) {
10798+
goto out;
10799+
}
10800+
allocator = NpyString_acquire_allocator(string_dtype);
10801+
if (allocator == NULL) {
10802+
goto out;
1080310803
}
10804+
array_data = (char *) PyArray_DATA((PyArrayObject *) array);
10805+
for (i = 0; i < num_rows; i++) {
1080410806

10805-
if (all_length_one && num_rows > 0) {
10806-
ret_array = (PyArrayObject *) PyArray_New(&PyArray_Type, /* subtype */
10807-
1, /* nd */
10808-
dims, /* dims */
10809-
NPY_STRING, /* type_num */
10810-
NULL, /* strides */
10811-
(void *) data, /* data */
10812-
1, /* itemsize = 1 for S1 */
10813-
0, /* flags */
10814-
NULL /* obj */
10815-
);
10816-
if (ret_array == NULL) {
10817-
goto out;
10818-
}
10807+
pack_result = NpyString_pack(allocator,
10808+
(npy_packed_static_string
10809+
*) (array_data + (i * ((PyArray_Descr *) string_dtype)->elsize)),
10810+
data + offset[i], offset[i + 1] - offset[i]);
1081910811

10820-
if (PyArray_SetBaseObject(ret_array, (PyObject *) self) != 0) {
10812+
if (pack_result == -1) {
1082110813
goto out;
1082210814
}
10823-
Py_INCREF(self);
10824-
10825-
} else {
10826-
/* We have to pad the strings to make a rectangular array */
10827-
max_length = 0;
10828-
for (i = 0; i < num_rows; i++) {
10829-
len = offset[i + 1] - offset[i];
10830-
if (len > max_length) {
10831-
max_length = len;
10832-
}
10833-
}
10834-
10835-
/* Ensure at least S1 */
10836-
if (max_length == 0) {
10837-
max_length = 1;
10838-
}
10839-
10840-
if (num_rows > 0) {
10841-
tsk_size_t total_size = num_rows * max_length;
10842-
string_data = PyMem_Calloc(total_size, sizeof(char));
10843-
if (string_data == NULL) {
10844-
PyErr_NoMemory();
10845-
goto out;
10846-
}
10847-
10848-
for (i = 0; i < num_rows; i++) {
10849-
start = offset[i];
10850-
end = offset[i + 1];
10851-
len = end - start;
10852-
10853-
dest = string_data + (i * max_length);
10854-
if (len > 0) {
10855-
memcpy(dest, data + start, len);
10856-
}
10857-
/* PyMem_Calloc zero-fills so we don't need to pad */
10858-
}
10859-
}
10815+
}
1086010816

10861-
ret_array = (PyArrayObject *) PyArray_New(&PyArray_Type, /* subtype */
10862-
1, /* nd */
10863-
dims, /* dims */
10864-
NPY_STRING, /* type_num */
10865-
NULL, /* strides */
10866-
string_data, /* data */
10867-
(int) max_length, /* itemsize */
10868-
NPY_ARRAY_OWNDATA, /* flags - we own the data */
10869-
NULL /* obj */
10870-
);
10871-
if (ret_array == NULL) {
10872-
goto out;
10873-
}
10874-
string_data = NULL; /* Array now owns the memory */
10817+
if (PyArray_SetBaseObject((PyArrayObject *) array, (PyObject *) self) != 0) {
10818+
goto out;
1087510819
}
10820+
Py_INCREF(self);
1087610821

10877-
/* Make array read-only for consistent semantics with tskit memory case */
10878-
PyArray_CLEARFLAGS(ret_array, NPY_ARRAY_WRITEABLE);
10822+
/* Clear the writeable flag to match other arrays semantics */
10823+
PyArray_CLEARFLAGS((PyArrayObject *) array, NPY_ARRAY_WRITEABLE);
1087910824

10880-
ret = (PyObject *) ret_array;
10881-
ret_array = NULL;
10825+
ret = array;
10826+
array = NULL;
1088210827

1088310828
out:
10884-
PyMem_Free(string_data);
10885-
Py_XDECREF(ret_array);
10829+
if (allocator != NULL) {
10830+
NpyString_release_allocator(allocator);
10831+
}
10832+
Py_XDECREF(array);
1088610833
return ret;
1088710834
}
1088810835

@@ -14705,11 +14652,16 @@ static struct PyModuleDef tskitmodule = {
1470514652
PyObject *
1470614653
PyInit__tskit(void)
1470714654
{
14708-
PyObject *module = PyModule_Create(&tskitmodule);
14655+
PyObject *module;
14656+
14657+
if (PyArray_ImportNumPyAPI() < 0) {
14658+
return NULL;
14659+
}
14660+
14661+
module = PyModule_Create(&tskitmodule);
1470914662
if (module == NULL) {
1471014663
return NULL;
1471114664
}
14712-
import_array();
1471314665

1471414666
if (register_lwt_class(module) != 0) {
1471514667
return NULL;

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ keywords = [
4343
requires-python = ">=3.9"
4444
dependencies = [
4545
"jsonschema>=3.0.0",
46-
"numpy>=1.23.5",
46+
"numpy>2",
4747
]
4848

4949
[project.urls]

python/requirements/development.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ msprime>=1.0.0
1717
networkx
1818
newick
1919
ninja
20-
numpy
20+
numpy>=2
2121
packaging
2222
portion
2323
pre-commit

python/requirements/development.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- msprime>=1.0.0
2323
- networkx
2424
- ninja
25-
- numpy<2
25+
- numpy>2
2626
- packaging
2727
- portion
2828
- pre-commit

python/tests/test_lowlevel.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,8 +1916,9 @@ def test_array_read_only(self, name, ts_fixture):
19161916
a[:] = 0
19171917
with pytest.raises(ValueError, match="assignment destination"):
19181918
a[0] = 0
1919-
with pytest.raises(ValueError, match="cannot set WRITEABLE"):
1920-
a.setflags(write=True)
1919+
if name != "sites_ancestral_state":
1920+
with pytest.raises(ValueError, match="cannot set WRITEABLE"):
1921+
a.setflags(write=True)
19211922

19221923
@pytest.mark.parametrize("name", ARRAY_NAMES)
19231924
def test_array_properties(self, name, ts_fixture):
@@ -1927,13 +1928,17 @@ def test_array_properties(self, name, ts_fixture):
19271928
assert not a.flags.writeable
19281929
assert a.flags.aligned
19291930
assert a.flags.c_contiguous
1930-
assert not a.flags.owndata
1931+
if name == "sites_ancestral_state":
1932+
assert a.flags.owndata
1933+
else:
1934+
assert not a.flags.owndata
19311935
b = getattr(ts_fixture, name)
19321936
assert a is not b
19331937
assert np.all(a == b)
19341938
# This checks that the underlying pointer to memory is the same in
19351939
# both arrays.
1936-
assert a.__array_interface__ == b.__array_interface__
1940+
if name != "sites_ancestral_state":
1941+
assert a.__array_interface__ == b.__array_interface__
19371942

19381943
@pytest.mark.parametrize("name", ARRAY_NAMES)
19391944
def test_array_lifetime(self, name, ts_fixture):
@@ -1976,7 +1981,8 @@ def test_individuals_nodes(self, ts_fixture):
19761981
assert a3 is not a2
19771982

19781983
@pytest.mark.parametrize(
1979-
"site_lengths", ["none", "all-0", "all-1", "all-2", "mixed", "very_long"]
1984+
"site_lengths",
1985+
["none", "all-0", "all-1", "all-2", "mixed", "very_long", "unicode"],
19801986
)
19811987
def test_sites_ancestral_state(self, ts_fixture, site_lengths):
19821988
if site_lengths == "none":
@@ -1995,6 +2001,7 @@ def test_sites_ancestral_state(self, ts_fixture, site_lengths):
19952001
"all-2": lambda i, site: "YO",
19962002
"mixed": lambda i, site: "A" * i,
19972003
"very_long": lambda i, site: "A" * 100_000_000 if i == 0 else "A",
2004+
"unicode": lambda i, site: "💩" * (i + 1),
19982005
}
19992006

20002007
get_ancestral_state = ancestral_state_map[site_lengths]
@@ -2003,15 +2010,22 @@ def test_sites_ancestral_state(self, ts_fixture, site_lengths):
20032010
site.replace(ancestral_state=get_ancestral_state(i, site))
20042011
)
20052012
ts = tables.tree_sequence()
2006-
ts = ts.ll_tree_sequence
2013+
ll_ts = ts.ll_tree_sequence
2014+
2015+
a = ll_ts.sites_ancestral_state
2016+
# Contents
2017+
if site_lengths == "none":
2018+
assert a.size == 0
2019+
else:
2020+
for site in ts.sites():
2021+
assert a[site.id] == site.ancestral_state
20072022

20082023
# Read only
20092024
with pytest.raises(AttributeError, match="not writable"):
2010-
ts.sites_ancestral_state = None
2025+
ll_ts.sites_ancestral_state = None
20112026
with pytest.raises(AttributeError, match="not writable"):
2012-
del ts.sites_ancestral_state
2027+
del ll_ts.sites_ancestral_state
20132028

2014-
a = ts.sites_ancestral_state
20152029
with pytest.raises(ValueError, match="assignment destination"):
20162030
a[:] = 0
20172031
with pytest.raises(ValueError, match="assignment destination"):
@@ -2021,18 +2035,17 @@ def test_sites_ancestral_state(self, ts_fixture, site_lengths):
20212035
a.setflags(write=True)
20222036

20232037
# Properties
2024-
a = ts.sites_ancestral_state
20252038
assert a.flags.aligned
20262039
assert a.flags.c_contiguous
2027-
b = ts.sites_ancestral_state
2040+
b = ll_ts.sites_ancestral_state
20282041
assert a is not b
20292042
assert np.all(a == b)
20302043

20312044
# Lifetime
2032-
a1 = ts.sites_ancestral_state
2045+
a1 = ll_ts.sites_ancestral_state
20332046
a2 = a1.copy()
20342047
assert a1 is not a2
2035-
del ts
2048+
del ll_ts
20362049
# Do some memory operations
20372050
a3 = np.ones(10**6)
20382051
assert np.all(a1 == a2)

0 commit comments

Comments
 (0)