Skip to content

Commit fe1d8ad

Browse files
kesmit13claude
andcommitted
Fix recv_exact protocol desync and unchecked PyObject_Length returns
Guard against protocol desynchronization when poll() times out after partial data has been consumed from the socket. In the C path (accel_recv_exact), switch to blocking mode when pos > 0 so the message is always completed. Apply the same fix to the Python fallback (_recv_exact_py) by catching TimeoutError mid-read and removing the socket timeout. Add error checking at all PyObject_Length call sites that cast the result to unsigned. PyObject_Length returns -1 on error, which when cast to unsigned long long produces ULLONG_MAX, leading to massive malloc allocations or out-of-bounds access. Each site now checks for < 0 and gotos error before casting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 34e2452 commit fe1d8ad

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

accel.c

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,11 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
22762276
orig_data = data;
22772277

22782278
// Get number of columns
2279-
n_cols = PyObject_Length(py_colspec);
2279+
{
2280+
Py_ssize_t tmp = PyObject_Length(py_colspec);
2281+
if (tmp < 0) goto error;
2282+
n_cols = (unsigned long long)tmp;
2283+
}
22802284

22812285
// Determine column types
22822286
ctypes = calloc(sizeof(int), n_cols);
@@ -2920,19 +2924,27 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29202924
goto error;
29212925
}
29222926

2923-
if (PyObject_Length(py_returns) != PyObject_Length(py_cols)) {
2924-
PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns");
2925-
goto error;
2927+
{
2928+
Py_ssize_t tmp_returns_l = PyObject_Length(py_returns);
2929+
if (tmp_returns_l < 0) goto error;
2930+
Py_ssize_t tmp_cols_l = PyObject_Length(py_cols);
2931+
if (tmp_cols_l < 0) goto error;
2932+
if (tmp_returns_l != tmp_cols_l) {
2933+
PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns");
2934+
goto error;
2935+
}
2936+
n_cols = (unsigned long long)tmp_returns_l;
29262937
}
29272938

2928-
n_rows = (unsigned long long)PyObject_Length(py_row_ids);
2939+
{
2940+
Py_ssize_t tmp = PyObject_Length(py_row_ids);
2941+
if (tmp < 0) goto error;
2942+
n_rows = (unsigned long long)tmp;
2943+
}
29292944
if (n_rows == 0) {
29302945
py_out = PyBytes_FromStringAndSize("", 0);
29312946
goto exit;
29322947
}
2933-
2934-
// Verify all data lengths agree
2935-
n_cols = (unsigned long long)PyObject_Length(py_returns);
29362948
if (n_cols == 0) {
29372949
py_out = PyBytes_FromStringAndSize("", 0);
29382950
goto exit;
@@ -2944,17 +2956,25 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29442956
PyObject *py_data = PyTuple_GetItem(py_item, 0);
29452957
if (!py_data) goto error;
29462958

2947-
if ((unsigned long long)PyObject_Length(py_data) != n_rows) {
2948-
PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values");
2949-
goto error;
2959+
{
2960+
Py_ssize_t tmp = PyObject_Length(py_data);
2961+
if (tmp < 0) goto error;
2962+
if ((unsigned long long)tmp != n_rows) {
2963+
PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values");
2964+
goto error;
2965+
}
29502966
}
29512967

29522968
PyObject *py_mask = PyTuple_GetItem(py_item, 1);
29532969
if (!py_mask) goto error;
29542970

2955-
if (py_mask != Py_None && (unsigned long long)PyObject_Length(py_mask) != n_rows) {
2956-
PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows");
2957-
goto error;
2971+
if (py_mask != Py_None) {
2972+
Py_ssize_t tmp = PyObject_Length(py_mask);
2973+
if (tmp < 0) goto error;
2974+
if ((unsigned long long)tmp != n_rows) {
2975+
PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows");
2976+
goto error;
2977+
}
29582978
}
29592979
}
29602980

@@ -4179,7 +4199,11 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
41794199
CHECKRC(PyBytes_AsStringAndSize(py_data, &data, &length));
41804200
end = data + (unsigned long long)length;
41814201

4182-
colspec_l = PyObject_Length(py_colspec);
4202+
{
4203+
Py_ssize_t tmp = PyObject_Length(py_colspec);
4204+
if (tmp < 0) goto error;
4205+
colspec_l = (unsigned long long)tmp;
4206+
}
41834207
ctypes = malloc(sizeof(int) * colspec_l);
41844208

41854209
for (i = 0; i < colspec_l; i++) {
@@ -4481,7 +4505,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44814505
goto error;
44824506
}
44834507

4484-
n_rows = (unsigned long long)PyObject_Length(py_rows);
4508+
{
4509+
Py_ssize_t tmp = PyObject_Length(py_rows);
4510+
if (tmp < 0) goto error;
4511+
n_rows = (unsigned long long)tmp;
4512+
}
44854513
if (n_rows == 0) {
44864514
py_out = PyBytes_FromStringAndSize("", 0);
44874515
goto exit;
@@ -4494,7 +4522,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44944522
if (!out) goto error;
44954523

44964524
// Get return types
4497-
n_cols = (unsigned long long)PyObject_Length(py_returns);
4525+
{
4526+
Py_ssize_t tmp = PyObject_Length(py_returns);
4527+
if (tmp < 0) goto error;
4528+
n_cols = (unsigned long long)tmp;
4529+
}
44984530
if (n_cols == 0) {
44994531
PyErr_SetString(PyExc_ValueError, "no return values specified");
45004532
goto error;
@@ -4809,7 +4841,11 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k
48094841
if (length == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; }
48104842

48114843
// Parse colspec types
4812-
colspec_l = (unsigned long long)PyObject_Length(py_colspec);
4844+
{
4845+
Py_ssize_t tmp = PyObject_Length(py_colspec);
4846+
if (tmp < 0) goto error;
4847+
colspec_l = (unsigned long long)tmp;
4848+
}
48134849
ctypes = malloc(sizeof(int) * colspec_l);
48144850
if (!ctypes) goto error;
48154851
for (i = 0; i < colspec_l; i++) {
@@ -4822,7 +4858,11 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k
48224858
}
48234859

48244860
// Parse return types
4825-
returns_l = (unsigned long long)PyObject_Length(py_returns);
4861+
{
4862+
Py_ssize_t tmp = PyObject_Length(py_returns);
4863+
if (tmp < 0) goto error;
4864+
returns_l = (unsigned long long)tmp;
4865+
}
48264866
rtypes = malloc(sizeof(int) * returns_l);
48274867
if (!rtypes) goto error;
48284868
for (i = 0; i < returns_l; i++) {
@@ -5529,6 +5569,12 @@ static PyObject *accel_recv_exact(PyObject *self, PyObject *args) {
55295569
poll_rc = poll(&pfd, 1, timeout_ms);
55305570
Py_END_ALLOW_THREADS
55315571
if (poll_rc == 0) {
5572+
if (pos > 0) {
5573+
/* Partial message already consumed — must finish it.
5574+
Block indefinitely to avoid protocol desync. */
5575+
timeout_ms = -1;
5576+
continue;
5577+
}
55325578
free(buf);
55335579
PyErr_SetString(PyExc_TimeoutError, "recv_exact timed out");
55345580
return NULL;

singlestoredb/functions/ext/collocated/connection.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,15 @@ def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None:
395395
view = memoryview(buf)
396396
pos = 0
397397
while pos < n:
398-
nbytes = sock.recv_into(view[pos:])
398+
try:
399+
nbytes = sock.recv_into(view[pos:])
400+
except TimeoutError:
401+
if pos == 0:
402+
raise
403+
# Partial message already consumed — must finish it.
404+
# Remove timeout to avoid protocol desync.
405+
sock.settimeout(None)
406+
continue
399407
if nbytes == 0:
400408
return None
401409
pos += nbytes

0 commit comments

Comments
 (0)