Skip to content

Commit 1cc541d

Browse files
authored
Merge pull request #1906 from weaviate/modules/cohere-reranker-baseurl
Add baseURL to cohere reranker
2 parents 1067ab8 + d5f56f0 commit 1cc541d

3 files changed

Lines changed: 19 additions & 5 deletions

File tree

integration/test_collection_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -913,9 +913,11 @@ def test_config_vector_index_hnsw_and_quantizer_pq(collection_factory: Collectio
913913
[
914914
(Configure.Reranker.cohere(), Rerankers.COHERE, {}),
915915
(
916-
Configure.Reranker.cohere(model="rerank-english-v2.0"),
916+
Configure.Reranker.cohere(
917+
model="rerank-english-v2.0", base_url="https://some-cohere-baseurl.ai/"
918+
),
917919
Rerankers.COHERE,
918-
{"model": "rerank-english-v2.0"},
920+
{"model": "rerank-english-v2.0", "baseURL": "https://some-cohere-baseurl.ai/"},
919921
),
920922
(Configure.Reranker.transformers(), Rerankers.TRANSFORMERS, {}),
921923
],

test/collection/test_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,10 +1098,11 @@ def test_config_with_generative(
10981098

10991099
TEST_CONFIG_WITH_RERANKER = [
11001100
(
1101-
Configure.Reranker.cohere(model="model"),
1101+
Configure.Reranker.cohere(model="model", base_url="https://some.base.url/"),
11021102
{
11031103
"reranker-cohere": {
11041104
"model": "model",
1105+
"baseURL": "https://some.base.url/",
11051106
},
11061107
},
11071108
),

weaviate/collections/classes/config.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616
from deprecation import deprecated as docstring_deprecated
17-
from pydantic import AnyHttpUrl, Field, ValidationInfo, field_validator
17+
from pydantic import AnyHttpUrl, AnyUrl, Field, ValidationInfo, field_validator
1818
from typing_extensions import TypeAlias
1919
from typing_extensions import deprecated as typing_deprecated
2020

@@ -536,6 +536,13 @@ class _RerankerCohereConfig(_RerankerProvider):
536536
default=Rerankers.COHERE, frozen=True, exclude=True
537537
)
538538
model: Optional[Union[RerankerCohereModel, str]] = Field(default=None)
539+
baseURL: Optional[AnyHttpUrl]
540+
541+
def _to_dict(self) -> Dict[str, Any]:
542+
ret_dict = super()._to_dict()
543+
if self.baseURL is not None:
544+
ret_dict["baseURL"] = self.baseURL.unicode_string()
545+
return ret_dict
539546

540547

541548
class _RerankerCustomConfig(_RerankerProvider):
@@ -1259,6 +1266,7 @@ def custom(
12591266
@staticmethod
12601267
def cohere(
12611268
model: Optional[Union[RerankerCohereModel, str]] = None,
1269+
base_url: Optional[str] = None,
12621270
) -> _RerankerProvider:
12631271
"""Create a `_RerankerCohereConfig` object for use when reranking using the `reranker-cohere` module.
12641272
@@ -1267,8 +1275,11 @@ def cohere(
12671275
12681276
Args:
12691277
model: The model to use. Defaults to `None`, which uses the server-defined default
1278+
base_url: The base URL to send the reranker requests to. Defaults to `None`, which uses the server-defined default.
12701279
"""
1271-
return _RerankerCohereConfig(model=model)
1280+
return _RerankerCohereConfig(
1281+
model=model, baseURL=AnyUrl(base_url) if base_url is not None else None
1282+
)
12721283

12731284
@staticmethod
12741285
def jinaai(

0 commit comments

Comments
 (0)