Skip to content

Commit abfcbfa

Browse files
authored
Fix HashModel list validation blocking vector fields (#785)
HashModel.__init_subclass__ rejected list fields unconditionally, but vector fields require list[float] type annotation. Now checks for vector_options before raising the error. Also adds automatic serialization of vector list[float] to bytes for Redis Hash storage, and deserialization back to list[float] on retrieval. Fixes #544
1 parent a134482 commit abfcbfa

2 files changed

Lines changed: 126 additions & 0 deletions

File tree

aredis_om/model/model.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import logging
66
import operator
7+
import struct
78
from copy import copy
89
from enum import Enum
910
from functools import reduce
@@ -323,6 +324,68 @@ def convert_base64_to_bytes(obj, model_fields):
323324
return obj
324325

325326

327+
def convert_vector_to_bytes(obj, model_fields):
328+
"""Convert list[float] vector fields to packed bytes for HashModel storage.
329+
330+
Redis Hash fields can only store scalar values (strings, bytes, numbers).
331+
Vector fields (list[float]) need to be serialized to bytes for storage.
332+
This uses little-endian float32 packing, matching the format expected by
333+
RediSearch for vector similarity queries.
334+
"""
335+
if not isinstance(obj, dict):
336+
return obj
337+
338+
result = {}
339+
for key, value in obj.items():
340+
if key in model_fields and isinstance(value, list):
341+
field_info = model_fields[key]
342+
vector_options = getattr(field_info, "vector_options", None)
343+
if vector_options is not None and value:
344+
# Pack floats as little-endian float32 bytes
345+
try:
346+
result[key] = struct.pack(f"<{len(value)}f", *value)
347+
except struct.error:
348+
# If packing fails, keep original value
349+
result[key] = value
350+
else:
351+
result[key] = value
352+
else:
353+
result[key] = value
354+
return result
355+
356+
357+
def convert_bytes_to_vector(obj, model_fields):
358+
"""Convert packed bytes back to list[float] for vector fields.
359+
360+
This reverses the conversion done by convert_vector_to_bytes.
361+
"""
362+
if not isinstance(obj, dict):
363+
return obj
364+
365+
result = {}
366+
for key, value in obj.items():
367+
if key in model_fields:
368+
field_info = model_fields[key]
369+
vector_options = getattr(field_info, "vector_options", None)
370+
if vector_options is not None and isinstance(value, (bytes, str)):
371+
# Handle bytes or string (Redis may return as string with decode_responses)
372+
try:
373+
if isinstance(value, str):
374+
# If decode_responses=True, we get a string - need to encode back
375+
value = value.encode("latin-1")
376+
# Unpack little-endian float32 bytes
377+
num_floats = len(value) // 4
378+
result[key] = list(struct.unpack(f"<{num_floats}f", value))
379+
except (struct.error, ValueError, UnicodeEncodeError):
380+
# If unpacking fails, keep original value
381+
result[key] = value
382+
else:
383+
result[key] = value
384+
else:
385+
result[key] = value
386+
return result
387+
388+
326389
class PartialModel:
327390
"""A partial model instance that only contains certain fields.
328391
@@ -2834,11 +2897,31 @@ class HashModel(RedisModel, abc.ABC):
28342897
def __init_subclass__(cls, **kwargs):
28352898
super().__init_subclass__(**kwargs)
28362899

2900+
# Helper to check if a field has vector_options (making it a vector field).
2901+
# We check cls.__dict__ because model_fields may not be populated yet
2902+
# when __init_subclass__ runs during class creation.
2903+
def _has_vector_options(field_name: str) -> bool:
2904+
"""Check if a field has vector_options set, making it a vector field."""
2905+
# First check cls.__dict__ for the original FieldInfo (before Pydantic processing)
2906+
if field_name in cls.__dict__:
2907+
field = cls.__dict__[field_name]
2908+
if getattr(field, "vector_options", None) is not None:
2909+
return True
2910+
# Also check model_fields in case it's populated
2911+
if hasattr(cls, "model_fields") and field_name in cls.model_fields:
2912+
field = cls.model_fields[field_name]
2913+
if getattr(field, "vector_options", None) is not None:
2914+
return True
2915+
return False
2916+
28372917
if hasattr(cls, "__annotations__"):
28382918
for name, field_type in cls.__annotations__.items():
28392919
origin = get_origin(field_type)
28402920
for typ in (Set, Mapping, List):
28412921
if isinstance(origin, type) and issubclass(origin, typ):
2922+
# Vector fields are allowed to be lists (list[float])
2923+
if _has_vector_options(name):
2924+
continue
28422925
raise RedisModelError(
28432926
f"HashModels cannot index set, list, "
28442927
f"or mapping fields. Field: {name}"
@@ -2860,6 +2943,9 @@ def __init_subclass__(cls, **kwargs):
28602943
if origin:
28612944
for typ in (Set, Mapping, List):
28622945
if issubclass(origin, typ):
2946+
# Vector fields are allowed to be lists (list[float])
2947+
if getattr(field, "vector_options", None) is not None:
2948+
continue
28632949
raise RedisModelError(
28642950
f"HashModels cannot index set, list, "
28652951
f"or mapping fields. Field: {name}"
@@ -2944,6 +3030,8 @@ async def save(
29443030
# Get model data and apply conversions in the correct order
29453031
document = self.model_dump()
29463032
document = convert_datetime_to_timestamp(document)
3033+
# Convert vector fields (list[float]) to bytes before base64 encoding
3034+
document = convert_vector_to_bytes(document, self.__class__.model_fields)
29473035
document = convert_bytes_to_base64(document)
29483036

29493037
# Then apply jsonable encoding for other types
@@ -3046,6 +3134,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
30463134
document = convert_timestamp_to_datetime(document, cls.model_fields)
30473135
# Convert base64 strings back to bytes for bytes fields
30483136
document = convert_base64_to_bytes(document, cls.model_fields)
3137+
# Convert bytes back to list[float] for vector fields
3138+
document = convert_bytes_to_vector(document, cls.model_fields)
30493139
result = cls.model_validate(document)
30503140
except TypeError as e:
30513141
log.warning(
@@ -3059,6 +3149,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
30593149
document = convert_timestamp_to_datetime(document, cls.model_fields)
30603150
# Convert base64 strings back to bytes for bytes fields
30613151
document = convert_base64_to_bytes(document, cls.model_fields)
3152+
# Convert bytes back to list[float] for vector fields
3153+
document = convert_bytes_to_vector(document, cls.model_fields)
30623154
result = cls.model_validate(document)
30633155
return result
30643156

tests/test_hash_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
QueryNotSupportedError,
2323
RedisModel,
2424
RedisModelError,
25+
VectorFieldOptions,
2526
)
2627
from aredis_om.model.model import ExpressionProxy
2728

@@ -1569,3 +1570,36 @@ class Meta:
15691570
assert retrieved.pk == 42
15701571
assert retrieved.x == 42
15711572
assert retrieved.name == "test"
1573+
1574+
1575+
@py_test_mark_asyncio
1576+
async def test_hashmodel_vector_field_with_list(key_prefix, redis):
1577+
"""Test that HashModel allows list[float] fields when used with vector_options.
1578+
1579+
Regression test for GitHub issue #544: HashModel rejected list fields
1580+
even when they were vector fields that require list[float] type.
1581+
"""
1582+
vector_options = VectorFieldOptions.flat(
1583+
type=VectorFieldOptions.TYPE.FLOAT32,
1584+
dimension=4,
1585+
distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
1586+
)
1587+
1588+
# This should NOT raise an error - vector fields are allowed to be lists
1589+
class VectorDocument(HashModel, index=True):
1590+
name: str
1591+
embedding: list[float] = Field(default=[], vector_options=vector_options)
1592+
1593+
class Meta:
1594+
global_key_prefix = key_prefix
1595+
database = redis
1596+
1597+
await Migrator().run()
1598+
1599+
# Create and save a document with a vector
1600+
doc = VectorDocument(name="test", embedding=[0.1, 0.2, 0.3, 0.4])
1601+
await doc.save()
1602+
1603+
# Retrieve and verify
1604+
retrieved = await VectorDocument.get(doc.pk)
1605+
assert retrieved.name == "test"

0 commit comments

Comments
 (0)