@@ -19,7 +19,29 @@ def test_init_invalid_top_k(self):
1919 SentenceTransformersSimilarityRanker (top_k = - 1 )
2020
2121 @patch ("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder" )
22- def test_init_onnx_backend (self , mocked_cross_encoder ):
22+ def test_init_warm_up_torch_backend (self , mocked_cross_encoder ):
23+ ranker = SentenceTransformersSimilarityRanker (
24+ model = "sentence-transformers/all-MiniLM-L6-v2" ,
25+ token = None ,
26+ device = ComponentDevice .from_str ("cpu" ),
27+ backend = "torch" ,
28+ trust_remote_code = True ,
29+ )
30+
31+ ranker .warm_up ()
32+ mocked_cross_encoder .assert_called_once_with (
33+ model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2" ,
34+ device = "cpu" ,
35+ token = None ,
36+ trust_remote_code = True ,
37+ model_kwargs = None ,
38+ tokenizer_kwargs = None ,
39+ config_kwargs = None ,
40+ backend = "torch" ,
41+ )
42+
43+ @patch ("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder" )
44+ def test_init_warm_up_onnx_backend (self , mocked_cross_encoder ):
2345 onnx_ranker = SentenceTransformersSimilarityRanker (
2446 model = "sentence-transformers/all-MiniLM-L6-v2" ,
2547 token = None ,
@@ -32,14 +54,15 @@ def test_init_onnx_backend(self, mocked_cross_encoder):
3254 model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2" ,
3355 device = "cpu" ,
3456 token = None ,
57+ trust_remote_code = False ,
3558 model_kwargs = None ,
3659 tokenizer_kwargs = None ,
3760 config_kwargs = None ,
3861 backend = "onnx" ,
3962 )
4063
4164 @patch ("haystack.components.rankers.sentence_transformers_similarity.CrossEncoder" )
42- def test_init_openvino_backend (self , mocked_cross_encoder ):
65+ def test_init_warm_up_openvino_backend (self , mocked_cross_encoder ):
4366 openvino_ranker = SentenceTransformersSimilarityRanker (
4467 model = "sentence-transformers/all-MiniLM-L6-v2" ,
4568 token = None ,
@@ -52,6 +75,7 @@ def test_init_openvino_backend(self, mocked_cross_encoder):
5275 model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2" ,
5376 device = "cpu" ,
5477 token = None ,
78+ trust_remote_code = False ,
5579 model_kwargs = None ,
5680 tokenizer_kwargs = None ,
5781 config_kwargs = None ,
@@ -74,6 +98,7 @@ def test_to_dict(self):
7498 "embedding_separator" : "\n " ,
7599 "scale_score" : True ,
76100 "score_threshold" : None ,
101+ "trust_remote_code" : False ,
77102 "model_kwargs" : None ,
78103 "tokenizer_kwargs" : None ,
79104 "config_kwargs" : None ,
@@ -92,6 +117,7 @@ def test_to_dict_with_custom_init_parameters(self):
92117 document_prefix = "document_instruction: " ,
93118 scale_score = False ,
94119 score_threshold = 0.01 ,
120+ trust_remote_code = True ,
95121 model_kwargs = {"torch_dtype" : torch .float16 },
96122 tokenizer_kwargs = {"model_max_length" : 512 },
97123 batch_size = 32 ,
@@ -110,6 +136,7 @@ def test_to_dict_with_custom_init_parameters(self):
110136 "embedding_separator" : "\n " ,
111137 "scale_score" : False ,
112138 "score_threshold" : 0.01 ,
139+ "trust_remote_code" : True ,
113140 "model_kwargs" : {"torch_dtype" : "torch.float16" },
114141 "tokenizer_kwargs" : {"model_max_length" : 512 },
115142 "config_kwargs" : None ,
@@ -141,6 +168,7 @@ def test_to_dict_with_quantization_options(self):
141168 "embedding_separator" : "\n " ,
142169 "scale_score" : True ,
143170 "score_threshold" : None ,
171+ "trust_remote_code" : False ,
144172 "model_kwargs" : {
145173 "load_in_4bit" : True ,
146174 "bnb_4bit_use_double_quant" : True ,
@@ -168,6 +196,7 @@ def test_from_dict(self):
168196 "embedding_separator" : "\n " ,
169197 "scale_score" : False ,
170198 "score_threshold" : 0.01 ,
199+ "trust_remote_code" : False ,
171200 "model_kwargs" : {"torch_dtype" : "torch.float16" },
172201 "tokenizer_kwargs" : None ,
173202 "config_kwargs" : None ,
@@ -187,6 +216,7 @@ def test_from_dict(self):
187216 assert component .embedding_separator == "\n "
188217 assert not component .scale_score
189218 assert component .score_threshold == 0.01
219+ assert component .trust_remote_code is False
190220 assert component .model_kwargs == {"torch_dtype" : torch .float16 }
191221 assert component .tokenizer_kwargs is None
192222 assert component .config_kwargs is None
@@ -209,6 +239,7 @@ def test_from_dict_no_default_parameters(self):
209239 assert component .embedding_separator == "\n "
210240 assert component .scale_score
211241 assert component .score_threshold is None
242+ assert component .trust_remote_code is False
212243 assert component .model_kwargs is None
213244 assert component .tokenizer_kwargs is None
214245 assert component .config_kwargs is None
0 commit comments