diff --git a/haystack/components/rankers/sentence_transformers_similarity.py b/haystack/components/rankers/sentence_transformers_similarity.py index 82e6498140..2040da3ec1 100644 --- a/haystack/components/rankers/sentence_transformers_similarity.py +++ b/haystack/components/rankers/sentence_transformers_similarity.py @@ -52,6 +52,7 @@ def __init__( # noqa: PLR0913 embedding_separator: str = "\n", scale_score: bool = True, score_threshold: Optional[float] = None, + trust_remote_code: bool = False, model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, config_kwargs: Optional[Dict[str, Any]] = None, @@ -84,6 +85,9 @@ def __init__( # noqa: PLR0913 If `False`, disables scaling of the raw logit predictions. :param score_threshold: Use it to return documents with a score above this threshold only. + :param trust_remote_code: + If `False`, allows only Hugging Face verified model architectures. + If `True`, allows custom models and scripts. :param model_kwargs: Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained` when loading the model. Refer to specific model documentation for available kwargs. @@ -119,6 +123,7 @@ def __init__( # noqa: PLR0913 self.embedding_separator = embedding_separator self.scale_score = scale_score self.score_threshold = score_threshold + self.trust_remote_code = trust_remote_code self.model_kwargs = model_kwargs self.tokenizer_kwargs = tokenizer_kwargs self.config_kwargs = config_kwargs @@ -140,6 +145,7 @@ def warm_up(self) -> None: model_name_or_path=self.model, device=self.device.to_torch_str(), token=self.token.resolve_value() if self.token else None, + trust_remote_code=self.trust_remote_code, model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, @@ -165,6 +171,7 @@ def to_dict(self) -> Dict[str, Any]: embedding_separator=self.embedding_separator, scale_score=self.scale_score, score_threshold=self.score_threshold, + trust_remote_code=self.trust_remote_code, model_kwargs=self.model_kwargs, tokenizer_kwargs=self.tokenizer_kwargs, config_kwargs=self.config_kwargs, diff --git a/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml b/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml new file mode 100644 index 0000000000..4e85328479 --- /dev/null +++ b/releasenotes/notes/st-similarity-ranker-trust-remote-code-7e00abfc96afa698.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Added a `trust_remote_code` parameter to the `SentenceTransformersSimilarityRanker` component. + When set to True, this enables execution of custom models and scripts hosted on the Hugging Face Hub. diff --git a/test/components/rankers/test_sentence_transformers_similarity.py b/test/components/rankers/test_sentence_transformers_similarity.py index 8d0836cc5e..d7e68589f5 100644 --- a/test/components/rankers/test_sentence_transformers_similarity.py +++ b/test/components/rankers/test_sentence_transformers_similarity.py @@ -19,7 +19,29 @@ def test_init_invalid_top_k(self): SentenceTransformersSimilarityRanker(top_k=-1) @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") - def test_init_onnx_backend(self, mocked_cross_encoder): + def test_init_warm_up_torch_backend(self, mocked_cross_encoder): + ranker = SentenceTransformersSimilarityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + token=None, + device=ComponentDevice.from_str("cpu"), + backend="torch", + trust_remote_code=True, + ) + + ranker.warm_up() + mocked_cross_encoder.assert_called_once_with( + model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", + device="cpu", + token=None, + trust_remote_code=True, + model_kwargs=None, + tokenizer_kwargs=None, + config_kwargs=None, + backend="torch", + ) + + @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") + def test_init_warm_up_onnx_backend(self, mocked_cross_encoder): onnx_ranker = SentenceTransformersSimilarityRanker( model="sentence-transformers/all-MiniLM-L6-v2", token=None, @@ -32,6 +54,7 @@ def test_init_onnx_backend(self, mocked_cross_encoder): model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu", token=None, + trust_remote_code=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None, @@ -39,7 +62,7 @@ def test_init_onnx_backend(self, mocked_cross_encoder): ) @patch("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder") - def test_init_openvino_backend(self, mocked_cross_encoder): + def test_init_warm_up_openvino_backend(self, mocked_cross_encoder): openvino_ranker = SentenceTransformersSimilarityRanker( model="sentence-transformers/all-MiniLM-L6-v2", token=None, @@ -52,6 +75,7 @@ def test_init_openvino_backend(self, mocked_cross_encoder): model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu", token=None, + trust_remote_code=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None, @@ -74,6 +98,7 @@ def test_to_dict(self): "embedding_separator": "\n", "scale_score": True, "score_threshold": None, + "trust_remote_code": False, "model_kwargs": None, "tokenizer_kwargs": None, "config_kwargs": None, @@ -92,6 +117,7 @@ def test_to_dict_with_custom_init_parameters(self): document_prefix="document_instruction: ", scale_score=False, score_threshold=0.01, + trust_remote_code=True, model_kwargs={"torch_dtype": torch.float16}, tokenizer_kwargs={"model_max_length": 512}, batch_size=32, @@ -110,6 +136,7 @@ def test_to_dict_with_custom_init_parameters(self): "embedding_separator": "\n", "scale_score": False, "score_threshold": 0.01, + "trust_remote_code": True, "model_kwargs": {"torch_dtype": "torch.float16"}, "tokenizer_kwargs": {"model_max_length": 512}, "config_kwargs": None, @@ -141,6 +168,7 @@ def test_to_dict_with_quantization_options(self): "embedding_separator": "\n", "scale_score": True, "score_threshold": None, + "trust_remote_code": False, "model_kwargs": { "load_in_4bit": True, "bnb_4bit_use_double_quant": True, @@ -168,6 +196,7 @@ def test_from_dict(self): "embedding_separator": "\n", "scale_score": False, "score_threshold": 0.01, + "trust_remote_code": False, "model_kwargs": {"torch_dtype": "torch.float16"}, "tokenizer_kwargs": None, "config_kwargs": None, @@ -187,6 +216,7 @@ def test_from_dict(self): assert component.embedding_separator == "\n" assert not component.scale_score assert component.score_threshold == 0.01 + assert component.trust_remote_code is False assert component.model_kwargs == {"torch_dtype": torch.float16} assert component.tokenizer_kwargs is None assert component.config_kwargs is None @@ -209,6 +239,7 @@ def test_from_dict_no_default_parameters(self): assert component.embedding_separator == "\n" assert component.scale_score assert component.score_threshold is None + assert component.trust_remote_code is False assert component.model_kwargs is None assert component.tokenizer_kwargs is None assert component.config_kwargs is None