|
4 | 4 | import onnx |
5 | 5 | import torch |
6 | 6 | import transformers |
7 | | -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch |
| 7 | +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout |
8 | 8 | from onnx_diagnostic.helpers import max_diff, string_type |
9 | 9 | from onnx_diagnostic.helpers.torch_helper import ( |
10 | 10 | dummy_llm, |
|
22 | 22 | from onnx_diagnostic.helpers.cache_helper import ( |
23 | 23 | make_dynamic_cache, |
24 | 24 | make_encoder_decoder_cache, |
25 | | - make_mamba_cache, |
26 | 25 | make_sliding_window_cache, |
27 | 26 | CacheKeyValue, |
28 | 27 | ) |
@@ -313,24 +312,6 @@ def test_torch_deepcopy_cache_dce(self): |
313 | 312 | self.assertEqual(hash1, hash2) |
314 | 313 | self.assertGreater(torch_tensor_size(cc), 1) |
315 | 314 |
|
316 | | - @requires_torch("4.50") |
317 | | - def test_torch_deepcopy_mamba_cache(self): |
318 | | - cache = make_mamba_cache( |
319 | | - [ |
320 | | - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), |
321 | | - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), |
322 | | - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), |
323 | | - ] |
324 | | - ) |
325 | | - at = torch_deepcopy(cache) |
326 | | - self.assertEqual(type(cache), type(at)) |
327 | | - self.assertEqual(max_diff(cache, at)["abs"], 0) |
328 | | - hash1 = string_type(at, with_shape=True, with_min_max=True) |
329 | | - cache.conv_states[0] += 1000 |
330 | | - hash2 = string_type(at, with_shape=True, with_min_max=True) |
331 | | - self.assertEqual(hash1, hash2) |
332 | | - self.assertGreater(torch_tensor_size(cache), 1) |
333 | | - |
334 | 315 | def test_torch_deepcopy_base_model_outputs(self): |
335 | 316 | bo = transformers.modeling_outputs.BaseModelOutput( |
336 | 317 | last_hidden_state=torch.rand((4, 4, 4)) |
|
0 commit comments