Skip to content

Commit 7b588c5

Browse files
authored
Merge pull request #1635 from weaviate/rob/muvera
Add muvera config
2 parents 8f61828 + 176def6 commit 7b588c5

4 files changed

Lines changed: 138 additions & 0 deletions

File tree

integration/test_collection_config.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,67 @@ def test_config_multi_vector_disabled(
14051405
assert conf.multi_vector is None
14061406

14071407

1408+
def test_config_muvera_enabled(
    collection_factory: CollectionFactory,
) -> None:
    """Check that a MUVERA encoding set at creation time is returned by ``config.get()``.

    The test is skipped when the server under test is older than 1.31.0.
    """
    dummy = collection_factory("dummy", ports=(8086, 50057))
    if dummy._connection._weaviate_version.is_lower_than(1, 31, 0):
        pytest.skip("Muvera is not supported in Weaviate versions lower than 1.31.0")

    collection = collection_factory(
        ports=(8086, 50057),
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=[
            Configure.NamedVectors.text2colbert_jinaai(
                name="vec",
                vectorize_collection_name=False,
                vector_index_config=Configure.VectorIndex.hnsw(
                    multi_vector=Configure.VectorIndex.MultiVector.multi_vector(
                        encoding=Configure.VectorIndex.MultiVector.Encoding.muvera()
                    )
                ),
            )
        ],
    )
    config = collection.config.get()
    assert config.vector_config is not None
    conf = config.vector_config["vec"].vector_index_config
    assert isinstance(conf, _VectorIndexConfigHNSW)
    # No version branch is needed here: both collections are created against the same
    # ports, and servers older than 1.31.0 have already been skipped above.
    assert conf.multi_vector is not None
    assert conf.multi_vector.encoding is not None
1439+
1440+
1441+
def test_config_muvera_disabled(
    collection_factory: CollectionFactory,
) -> None:
    """Check that multi-vector configured without an encoding reads back with no encoding.

    The test is skipped when the server under test is older than 1.29.0.
    """
    probe = collection_factory("dummy", ports=(8086, 50057))
    if probe._connection._weaviate_version.is_lower_than(1, 29, 0):
        pytest.skip("Multivector is not supported in Weaviate versions lower than 1.29.0")

    # Multi-vector is requested with its defaults, i.e. without any encoding.
    hnsw_config = Configure.VectorIndex.hnsw(
        multi_vector=Configure.VectorIndex.MultiVector.multi_vector()
    )
    named_vector = Configure.NamedVectors.text2colbert_jinaai(
        name="vec",
        vectorize_collection_name=False,
        vector_index_config=hnsw_config,
    )
    collection = collection_factory(
        ports=(8086, 50057),
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=[named_vector],
    )

    vector_config = collection.config.get().vector_config
    assert vector_config is not None
    index_config = vector_config["vec"].vector_index_config
    assert isinstance(index_config, _VectorIndexConfigHNSW)
    multi_vector = index_config.multi_vector
    assert multi_vector is not None
    assert multi_vector.encoding is None
1467+
1468+
14081469
@pytest.mark.parametrize(
14091470
"generative_config",
14101471
[

weaviate/collections/classes/config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
)
3434
from weaviate.collections.classes.config_vector_index import (
3535
VectorFilterStrategy,
36+
_MuveraConfigCreate,
37+
_EncodingConfigCreate,
3638
_MultiVectorConfigCreate,
3739
_QuantizerConfigCreate,
3840
_VectorIndexConfigCreate,
@@ -1556,8 +1558,20 @@ class _SQConfig(_ConfigBase):
15561558
SQConfig = _SQConfig
15571559

15581560

1561+
@dataclass
class _MuveraConfig(_ConfigBase):
    """MUVERA encoding settings as they appear in a collection's returned configuration."""

    # Whether the MUVERA encoding is switched on.
    enabled: Optional[bool]
    # MUVERA tuning parameters; see the Weaviate MUVERA documentation for their meaning.
    ksim: Optional[int]
    dprojections: Optional[int]
    repetitions: Optional[int]


# Public alias of the private dataclass.
MuveraConfig = _MuveraConfig
1570+
1571+
15591572
@dataclass
class _MultiVectorConfig(_ConfigBase):
    """Multi-vector settings of a vector index as they appear in a returned configuration."""

    # MUVERA encoding settings, or None when no encoding is configured.
    encoding: Optional[_MuveraConfig]
    # Name of the aggregation, stored as a plain string.
    aggregation: str
15621576

15631577

@@ -2031,12 +2045,31 @@ def __add_props(
20312045
ret_dict["properties"] = existing_props
20322046

20332047

2048+
class _VectorIndexMultivectorEncoding:
    """Factory namespace for multi-vector encoding configurations."""

    @staticmethod
    def muvera(
        ksim: Optional[int] = None,
        dprojections: Optional[int] = None,
        repetitions: Optional[int] = None,
    ) -> _EncodingConfigCreate:
        """Create an enabled MUVERA encoding configuration.

        Args:
            ksim: MUVERA `ksim` parameter, forwarded unchanged (`None` leaves it unset).
            dprojections: MUVERA `dprojections` parameter, forwarded unchanged.
            repetitions: MUVERA `repetitions` parameter, forwarded unchanged.

        Returns:
            A `_MuveraConfigCreate` with `enabled=True` and the given parameters.
        """
        muvera_config = _MuveraConfigCreate(
            enabled=True,
            ksim=ksim,
            dprojections=dprojections,
            repetitions=repetitions,
        )
        return muvera_config
2061+
2062+
20342063
class _VectorIndexMultiVector:
    """Factory namespace for the multi-vector section of a vector index configuration."""

    # Exposes the encoding factories, e.g. `MultiVector.Encoding.muvera()`.
    Encoding = _VectorIndexMultivectorEncoding

    @staticmethod
    def multi_vector(
        encoding: Optional[_EncodingConfigCreate] = None,
        aggregation: Optional[MultiVectorAggregation] = None,
    ) -> _MultiVectorConfigCreate:
        """Create a multi-vector configuration.

        Args:
            encoding: Optional encoding configuration; `None` means no encoding.
            aggregation: Optional aggregation; only its `.value` is stored.

        Returns:
            A `_MultiVectorConfigCreate` carrying the encoding and the aggregation value.
        """
        return _MultiVectorConfigCreate(
            # Passed through as-is: `x if x is not None else None` is just `x`.
            encoding=encoding,
            aggregation=aggregation.value if aggregation is not None else None,
        )
20422075

weaviate/collections/classes/config_methods.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_InvertedIndexConfig,
2222
_MultiTenancyConfig,
2323
_MultiVectorConfig,
24+
_MuveraConfig,
2425
_NamedVectorConfig,
2526
_NamedVectorizerConfig,
2627
_NestedProperty,
@@ -149,12 +150,31 @@ def __get_quantizer_config(
149150
return quantizer
150151

151152

153+
def __get_multivector_encoding(config: Dict[str, Any]) -> Optional[_MuveraConfig]:
    """Build the MUVERA encoding config from the multi-vector section of a schema.

    Args:
        config: The multi-vector section of a vector index config; its optional
            "muvera" entry is inspected.

    Returns:
        A `_MuveraConfig` when a "muvera" entry exists and its "enabled" flag is truthy,
        otherwise `None`.
    """
    muvera = config.get("muvera")
    if muvera is None or not muvera.get("enabled"):
        return None
    # The tuning parameters are Optional on `_MuveraConfig`, so an entry that omits one
    # of them maps to `None` instead of raising `KeyError`.
    return _MuveraConfig(
        enabled=muvera["enabled"],
        ksim=muvera.get("ksim"),
        dprojections=muvera.get("dprojections"),
        repetitions=muvera.get("repetitions"),
    )
165+
166+
152167
def __get_multivector(config: Dict[str, Any]) -> Optional[_MultiVectorConfig]:
    """Build the multi-vector config from a vector index config.

    Args:
        config: A vector index config; its optional "multivector" entry is inspected.

    Returns:
        A `_MultiVectorConfig` when a "multivector" entry exists and its "enabled" flag is
        truthy, otherwise `None`.
    """
    multivector = config.get("multivector")
    if multivector is None or not multivector.get("enabled"):
        return None
    return _MultiVectorConfig(
        # The helper already returns None when there is no "muvera" entry, so no
        # separate presence check is needed here.
        encoding=__get_multivector_encoding(multivector),
        aggregation=multivector["aggregation"],
    )
@@ -244,6 +264,7 @@ def __get_vector_config(
244264
),
245265
vector_index_config=vector_index_config,
246266
)
267+
247268
return named_vectors
248269
else:
249270
return None

weaviate/collections/classes/config_vector_index.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,26 @@ class _MultiVectorConfigCreateBase(_ConfigCreateModel):
4242
enabled: bool = Field(default=True)
4343

4444

45+
class _EncodingConfigCreate(_MultiVectorConfigCreateBase):
    """Base class for multi-vector encoding configurations.

    The `enabled` field (default `True`) is inherited from `_MultiVectorConfigCreateBase`
    and is therefore not declared again here.
    """

    @staticmethod
    @abstractmethod
    def encoding_name() -> str:
        """Return the key under which this encoding is placed in the multi-vector dict."""
        ...
51+
52+
53+
class _MuveraConfigCreate(_EncodingConfigCreate):
    """Creation-time settings for the MUVERA multi-vector encoding."""

    # MUVERA tuning parameters; each one may be None.
    ksim: Optional[int]
    dprojections: Optional[int]
    repetitions: Optional[int]

    @staticmethod
    def encoding_name() -> str:
        """Return the name identifying this encoding, ``"muvera"``."""
        name = "muvera"
        return name
61+
62+
4563
class _MultiVectorConfigCreate(_MultiVectorConfigCreateBase):
    """Creation-time multi-vector settings of a vector index."""

    # Declared with `exclude=True`, so it is left out of this model's own serialization.
    encoding: Optional[_EncodingConfigCreate] = Field(exclude=True)
    # Aggregation name as a plain string, or None.
    aggregation: Optional[str]
4766

4867

@@ -61,6 +80,10 @@ def _to_dict(self) -> Dict[str, Any]:
6180
ret_dict[self.quantizer.quantizer_name()] = self.quantizer._to_dict()
6281
if self.distance is not None:
6382
ret_dict["distance"] = str(self.distance.value)
83+
if self.multivector is not None and self.multivector.encoding is not None:
84+
ret_dict["multivector"][self.multivector.encoding.encoding_name()] = (
85+
self.multivector.encoding._to_dict()
86+
)
6487

6588
return ret_dict
6689

0 commit comments

Comments
 (0)