diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a71dbbb..006f29a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,6 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v7 + - run: sudo locale-gen fr_FR.UTF-8 - run: ./scripts/vendor.sh - run: make loadable static - run: uv sync --directory tests @@ -75,6 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v7 + - run: sudo locale-gen fr_FR.UTF-8 - run: ./scripts/vendor.sh - run: make loadable static - run: uv sync --directory tests diff --git a/sqlite-vec.c b/sqlite-vec.c index 7af3b6a..58f5088 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -69,6 +69,80 @@ typedef size_t usize; #define countof(x) (sizeof(x) / sizeof((x)[0])) #define min(a, b) (((a) <= (b)) ? (a) : (b)) +// Locale-independent strtod for parsing JSON floats. +// strtod(3) respects LC_NUMERIC, so "1.5" fails in French/German locales +// where the decimal separator is a comma. This parser always uses '.'. +static double strtod_c(const char *str, char **endptr) { + const char *p = str; + double result = 0.0; + int sign = 1; + int has_digits = 0; + + while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') { + p++; + } + + if (*p == '-') { + sign = -1; + p++; + } else if (*p == '+') { + p++; + } + + while (*p >= '0' && *p <= '9') { + result = result * 10.0 + (*p - '0'); + p++; + has_digits = 1; + } + + if (*p == '.') { + double fraction = 0.0; + double divisor = 1.0; + p++; + while (*p >= '0' && *p <= '9') { + fraction = fraction * 10.0 + (*p - '0'); + divisor *= 10.0; + p++; + has_digits = 1; + } + result += fraction / divisor; + } + + if ((*p == 'e' || *p == 'E') && has_digits) { + int exp_sign = 1; + int exponent = 0; + p++; + if (*p == '-') { + exp_sign = -1; + p++; + } else if (*p == '+') { + p++; + } + while (*p >= '0' && *p <= '9') { + exponent = exponent * 10 + (*p - '0'); + p++; + } + if (exponent > 0) { + double exp_mult = pow(10.0, (double)exponent); + if (exp_sign == 1) { + result *= exp_mult; + } else { + result /= exp_mult; + } + } + } + + if (endptr) { + *endptr = (char *)(has_digits ? p : str); + } + + if (result == HUGE_VAL || result == -HUGE_VAL) { + errno = ERANGE; + } + + return sign * result; +} + #ifndef SQLITE_VEC_ENABLE_RESCORE #define SQLITE_VEC_ENABLE_RESCORE 1 #endif @@ -1043,7 +1117,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector, char *endptr; errno = 0; - double result = strtod(ptr, &endptr); + double result = strtod_c(ptr, &endptr); if ((errno != 0 && result == 0) // some interval error? || (errno == ERANGE && (result == HUGE_VAL || result == -HUGE_VAL)) // too big / smalls diff --git a/tests/test-locale.py b/tests/test-locale.py new file mode 100644 index 0000000..ec20541 --- /dev/null +++ b/tests/test-locale.py @@ -0,0 +1,46 @@ +"""Tests for locale-independent JSON float parsing (issue #241). + +strtod(3) respects LC_NUMERIC, so locales like fr_FR or de_DE that use ',' +as the decimal separator would cause vec0 to reject valid JSON vectors like +'[0.1, 0.2, 0.3]'. The fix replaces strtod with a custom locale-independent +parser (strtod_c) that always treats '.' as the decimal separator. +""" + +import locale +import struct +import pytest +from helpers import _f32 + + +def test_vec0_locale_independent(db): + db.execute("create virtual table v using vec0(embedding float[3])") + + original = locale.setlocale(locale.LC_NUMERIC) + locale_changed = False + + for candidate in ("fr_FR.UTF-8", "de_DE.UTF-8", "it_IT.UTF-8", "fr_FR", "de_DE"): + try: + locale.setlocale(locale.LC_NUMERIC, candidate) + locale_changed = True + break + except locale.Error: + continue + + try: + db.execute("insert into v(rowid, embedding) values (1, '[0.1, 0.2, 0.3]')") + db.execute("insert into v(rowid, embedding) values (2, '[1.23, 4.56, 7.89]')") + db.execute("insert into v(rowid, embedding) values (3, '[1e-3, 2.5e2, -0.75]')") + + row = db.execute("select embedding from v where rowid = 1").fetchone() + assert row[0] == _f32([0.1, 0.2, 0.3]) + + row = db.execute("select embedding from v where rowid = 2").fetchone() + assert row[0] == _f32([1.23, 4.56, 7.89]) + + row = db.execute("select embedding from v where rowid = 3").fetchone() + assert row[0] == _f32([1e-3, 2.5e2, -0.75]) + finally: + locale.setlocale(locale.LC_NUMERIC, original) + + if not locale_changed: + pytest.skip("No non-C locale available to test locale independence")