@@ -207,12 +207,40 @@ def test_load_pretrained_quantized(
207207
208208 # Load the model back from the same path
209209 loaded_model = StaticModel .from_pretrained (save_path , quantize_to = "float32" )
210-
211210 # Assert that the loaded model has the same properties as the original one
212211 assert loaded_model .embedding .dtype == np .float32
213212 assert loaded_model .embedding .shape == mock_vectors .shape
214213
215214
215+ def test_load_pretrained_dim (
216+ tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
217+ ) -> None :
218+ """Test loading a pretrained model with dimensionality."""
219+ # Save the model to a temporary path
220+ model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
221+ save_path = tmp_path / "saved_model"
222+ model .save_pretrained (save_path )
223+
224+ loaded_model = StaticModel .from_pretrained (save_path , dimensionality = 2 )
225+
226+ # Assert that the loaded model has the same properties as the original one
227+ np .testing .assert_array_equal (loaded_model .embedding , mock_vectors [:, :2 ])
228+ assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
229+ assert loaded_model .config == mock_config
230+
231+ # Load the model back from the same path
232+ loaded_model = StaticModel .from_pretrained (save_path , dimensionality = None )
233+
234+ # Assert that the loaded model has the same properties as the original one
235+ np .testing .assert_array_equal (loaded_model .embedding , mock_vectors )
236+ assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
237+ assert loaded_model .config == mock_config
238+
239+ # Load the model back from the same path
240+ with pytest .raises (ValueError ):
241+ StaticModel .from_pretrained (save_path , dimensionality = 3000 )
242+
243+
216244def test_initialize_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
217245 """Tests whether the normalization initialization is correct."""
218246 model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = None )
0 commit comments