Skip to content

Commit 5377083

Browse files
kesmit13claude
andcommitted
Add float16 (half-precision) vector support
This commit adds comprehensive support for float16 (F16) vectors to both MySQL and HTTP clients in the SingleStoreDB Python SDK.

Changes:
- Add FLOAT16 = 7 as the 7th vector type constant
- Add FIELD_TYPE constants: FLOAT16_VECTOR_JSON (2007) and FLOAT16_VECTOR (3007)
- Update protocol parser to recognize and handle FLOAT16 vector metadata
- Add float16_vector_json_or_none() and float16_vector_or_none() converters
- Register float16 converters in the converters dictionary (types 2007 and 3007)
- Add FLOAT16 vector types to TEXT_TYPES set for proper type handling
- Update C accelerator with full float16 support:
  - Add FLOAT16 constants and PyStrings struct member
  - Update type arrays with 'e' format (2 bytes) for struct.unpack
  - Add float16 to JSON and binary vector case statements
  - Initialize numpy dtype kwargs for float16
- Add comprehensive tests:
  - Create f16_vectors test table with 3 test rows
  - Implement test_f16_vectors() method following existing patterns
  - Use assert_array_almost_equal with decimal=2 for float16 precision

Technical notes:
- Float16 has ~3 decimal digits of precision (vs ~7 for float32)
- Uses struct format 'e' for half-precision (2 bytes per element)
- Supports both JSON and binary wire formats
- All pre-commit hooks passed (flake8, autopep8, mypy)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 37f0e64 commit 5377083

File tree

8 files changed

+126
-4
lines changed

8 files changed

+126
-4
lines changed

accel.c

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,14 @@
9191
#define MYSQL_TYPE_INT16_VECTOR_JSON 2004
9292
#define MYSQL_TYPE_INT32_VECTOR_JSON 2005
9393
#define MYSQL_TYPE_INT64_VECTOR_JSON 2006
94+
#define MYSQL_TYPE_FLOAT16_VECTOR_JSON 2007
9495
#define MYSQL_TYPE_FLOAT32_VECTOR 3001
9596
#define MYSQL_TYPE_FLOAT64_VECTOR 3002
9697
#define MYSQL_TYPE_INT8_VECTOR 3003
9798
#define MYSQL_TYPE_INT16_VECTOR 3004
9899
#define MYSQL_TYPE_INT32_VECTOR 3005
99100
#define MYSQL_TYPE_INT64_VECTOR 3006
101+
#define MYSQL_TYPE_FLOAT16_VECTOR 3007
100102

101103
#define MYSQL_TYPE_CHAR MYSQL_TYPE_TINY
102104
#define MYSQL_TYPE_INTERVAL MYSQL_TYPE_ENUM
@@ -503,6 +505,7 @@ typedef struct {
503505
PyObject *int64;
504506
PyObject *float32;
505507
PyObject *float64;
508+
PyObject *float16;
506509
PyObject *unpack;
507510
PyObject *decode;
508511
PyObject *frombuffer;
@@ -541,7 +544,7 @@ typedef struct {
541544
PyObject *namedtuple_kwargs;
542545
PyObject *create_numpy_array_args;
543546
PyObject *create_numpy_array_kwargs;
544-
PyObject *create_numpy_array_kwargs_vector[7];
547+
PyObject *create_numpy_array_kwargs_vector[8];
545548
PyObject *struct_unpack_args;
546549
PyObject *bson_decode_args;
547550
} PyObjects;
@@ -1565,8 +1568,8 @@ static PyObject *read_row_from_packet(
15651568
PyObject *py_str = NULL;
15661569
PyObject *py_memview = NULL;
15671570
char end = '\0';
1568-
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q"};
1569-
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8};
1571+
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q", "e"};
1572+
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8, 2};
15701573

15711574
int sign = 1;
15721575
int year = 0;
@@ -1826,6 +1829,7 @@ static PyObject *read_row_from_packet(
18261829
case MYSQL_TYPE_INT16_VECTOR_JSON:
18271830
case MYSQL_TYPE_INT32_VECTOR_JSON:
18281831
case MYSQL_TYPE_INT64_VECTOR_JSON:
1832+
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
18291833
if (!py_state->encodings[i]) {
18301834
py_item = PyBytes_FromStringAndSize(out, out_l);
18311835
if (!py_item) goto error;
@@ -1847,7 +1851,7 @@ static PyObject *read_row_from_packet(
18471851
// Parse JSON string.
18481852
if ((py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json)
18491853
|| (py_state->type_codes[i] >= MYSQL_TYPE_FLOAT32_VECTOR_JSON
1850-
&& py_state->type_codes[i] <= MYSQL_TYPE_INT64_VECTOR_JSON)) {
1854+
&& py_state->type_codes[i] <= MYSQL_TYPE_FLOAT16_VECTOR_JSON)) {
18511855
py_str = py_item;
18521856
py_item = PyObject_CallFunctionObjArgs(PyFunc.json_loads, py_str, NULL);
18531857
Py_CLEAR(py_str);
@@ -1862,6 +1866,7 @@ static PyObject *read_row_from_packet(
18621866
case MYSQL_TYPE_INT16_VECTOR_JSON:
18631867
case MYSQL_TYPE_INT32_VECTOR_JSON:
18641868
case MYSQL_TYPE_INT64_VECTOR_JSON:
1869+
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
18651870
CHECKRC(PyTuple_SetItem(PyObj.create_numpy_array_args, 0, py_item));
18661871
py_item = PyObject_Call(
18671872
PyFunc.numpy_array,
@@ -1880,6 +1885,7 @@ static PyObject *read_row_from_packet(
18801885
case MYSQL_TYPE_INT16_VECTOR:
18811886
case MYSQL_TYPE_INT32_VECTOR:
18821887
case MYSQL_TYPE_INT64_VECTOR:
1888+
case MYSQL_TYPE_FLOAT16_VECTOR:
18831889
{
18841890
int type_idx = py_state->type_codes[i] % 1000;
18851891

@@ -4844,6 +4850,7 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
48444850
PyStr.int64 = PyUnicode_FromString("int64");
48454851
PyStr.float32 = PyUnicode_FromString("float32");
48464852
PyStr.float64 = PyUnicode_FromString("float64");
4853+
PyStr.float16 = PyUnicode_FromString("float16");
48474854
PyStr.unpack = PyUnicode_FromString("unpack");
48484855
PyStr.decode = PyUnicode_FromString("decode");
48494856
PyStr.frombuffer = PyUnicode_FromString("frombuffer");
@@ -4921,6 +4928,11 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
49214928
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[6], "dtype", PyStr.int64)) {
49224929
goto error;
49234930
}
4931+
PyObj.create_numpy_array_kwargs_vector[7] = PyDict_New();
4932+
if (!PyObj.create_numpy_array_kwargs_vector[7]) goto error;
4933+
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[7], "dtype", PyStr.float16)) {
4934+
goto error;
4935+
}
49244936

49254937
PyObj.struct_unpack_args = PyTuple_New(2);
49264938
if (!PyObj.struct_unpack_args) goto error;

singlestoredb/converters.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,62 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
597597
return struct.unpack(f'<{len(x)//4}f', x)
598598

599599

600+
def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
    """
    Convert value to float16 array.

    Parameters
    ----------
    x : str or None
        JSON array

    Returns
    -------
    float16 numpy array
        If input value is not None and numpy is installed
    float Python list
        If input value is not None and numpy is not installed
    None
        If input value is None

    """
    if x is None:
        return None

    if has_numpy:
        return numpy.array(json_loads(x), dtype=numpy.float16)

    # Materialize the values: a bare ``map`` object is a one-shot iterator
    # that cannot be indexed, measured, or re-read, and it would not match
    # the documented list return.
    return [float(v) for v in json_loads(x)]
626+
627+
628+
def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
    """
    Decode a packed block of half-precision floats.

    Parameters
    ----------
    x : bytes or None
        Little-endian block of bytes, two bytes per element.

    Returns
    -------
    float16 numpy array
        If input value is not None and numpy is available
    tuple of Python floats
        If input value is not None and numpy is not available
    None
        If input value is None

    """
    if x is not None:
        if not has_numpy:
            # 'e' is the struct code for a 2-byte half-precision float.
            n_items = len(x) // 2
            return struct.unpack(f'<{n_items}e', x)

        # NOTE(review): numpy.float16 is native-endian whereas the struct
        # fallback above is explicitly little-endian; the two paths agree
        # only on little-endian hosts -- confirm this is acceptable.
        return numpy.frombuffer(x, dtype=numpy.float16)

    return None
654+
655+
600656
def float64_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
601657
"""
602658
Covert value to float64 array.
@@ -941,10 +997,12 @@ def bson_or_none(x: Optional[bytes]) -> Optional[Any]:
941997
2004: int16_vector_json_or_none,
942998
2005: int32_vector_json_or_none,
943999
2006: int64_vector_json_or_none,
1000+
2007: float16_vector_json_or_none,
9441001
3001: float32_vector_or_none,
9451002
3002: float64_vector_or_none,
9461003
3003: int8_vector_or_none,
9471004
3004: int16_vector_or_none,
9481005
3005: int32_vector_or_none,
9491006
3006: int64_vector_or_none,
1007+
3007: float16_vector_or_none,
9501008
}

singlestoredb/mysql/connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@
110110
FIELD_TYPE.INT16_VECTOR_JSON,
111111
FIELD_TYPE.INT32_VECTOR_JSON,
112112
FIELD_TYPE.INT64_VECTOR_JSON,
113+
FIELD_TYPE.FLOAT16_VECTOR_JSON,
113114
FIELD_TYPE.FLOAT32_VECTOR,
114115
FIELD_TYPE.FLOAT64_VECTOR,
115116
FIELD_TYPE.INT8_VECTOR,
116117
FIELD_TYPE.INT16_VECTOR,
117118
FIELD_TYPE.INT32_VECTOR,
118119
FIELD_TYPE.INT64_VECTOR,
120+
FIELD_TYPE.FLOAT16_VECTOR,
119121
}
120122

121123
UNSET = 'unset'

singlestoredb/mysql/constants/FIELD_TYPE.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
INT16_VECTOR_JSON = 2004
4141
INT32_VECTOR_JSON = 2005
4242
INT64_VECTOR_JSON = 2006
43+
FLOAT16_VECTOR_JSON = 2007
4344
FLOAT32_VECTOR = 3001
4445
FLOAT64_VECTOR = 3002
4546
INT8_VECTOR = 3003
4647
INT16_VECTOR = 3004
4748
INT32_VECTOR = 3005
4849
INT64_VECTOR = 3006
50+
FLOAT16_VECTOR = 3007

singlestoredb/mysql/constants/VECTOR_TYPE.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
INT16 = 4
55
INT32 = 5
66
INT64 = 6
7+
FLOAT16 = 7

singlestoredb/mysql/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ def _parse_field_descriptor(self, encoding):
318318
self.type_code = FIELD_TYPE.INT64_VECTOR
319319
else:
320320
self.type_code = FIELD_TYPE.INT64_VECTOR_JSON
321+
elif vec_type == VECTOR_TYPE.FLOAT16:
322+
if self.charsetnr == 63:
323+
self.type_code = FIELD_TYPE.FLOAT16_VECTOR
324+
else:
325+
self.type_code = FIELD_TYPE.FLOAT16_VECTOR_JSON
321326
else:
322327
raise TypeError(f'unrecognized vector data type: {vec_type}')
323328
else:

singlestoredb/tests/test.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,14 @@ INSERT INTO i64_vectors VALUES(1, '[1, 2, 3]');
676676
INSERT INTO i64_vectors VALUES(2, '[4, 5, 6]');
677677
INSERT INTO i64_vectors VALUES(3, '[-1, -4, 8]');
678678

679+
CREATE TABLE IF NOT EXISTS `f16_vectors` (
680+
id INT(11),
681+
a VECTOR(3, F16)
682+
);
683+
INSERT INTO f16_vectors VALUES(1, '[0.267, 0.535, 0.802]');
684+
INSERT INTO f16_vectors VALUES(2, '[0.371, 0.557, 0.743]');
685+
INSERT INTO f16_vectors VALUES(3, '[-0.424, -0.566, 0.707]');
686+
679687

680688
--
681689
-- Boolean test data for UDF testing

singlestoredb/tests/test_connection.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,6 +3097,40 @@ def test_i64_vectors(self):
30973097
np.array([-1, -4, 8], dtype=np.int64),
30983098
)
30993099

3100+
def test_f16_vectors(self):
    """Check that the rows of the ``f16_vectors`` table read back as float16 vectors."""
    if self.conn.driver in ['http', 'https']:
        self.skipTest('Data API does not surface vector information')

    self.cur.execute('show variables like "enable_extended_types_metadata"')
    out = list(self.cur)
    if not out or out[0][1].lower() == 'off':
        self.skipTest('Database engine does not support extended types metadata')

    self.cur.execute('select a from f16_vectors order by id')
    out = list(self.cur)

    # Expected vectors, in id order.
    expected = [
        [0.267, 0.535, 0.802],
        [0.371, 0.557, 0.743],
        [-0.424, -0.566, 0.707],
    ]

    # Fail with a readable message instead of an IndexError when rows are missing.
    assert len(out) >= len(expected), out

    if hasattr(out[0][0], 'dtype'):
        for row in out[:len(expected)]:
            # Compare dtypes by equality; identity relies on numpy handing
            # back the same descriptor object every time.
            assert row[0].dtype == np.dtype('float16'), row[0].dtype

    # Float16 has ~3 decimal digits precision, use lower tolerance
    for row, values in zip(out, expected):
        np.testing.assert_array_almost_equal(
            row[0],
            np.array(values, dtype=np.float16),
            decimal=2,
        )
3133+
31003134

31013135
if __name__ == '__main__':
31023136
import nose2

0 commit comments

Comments
 (0)