Skip to content

Commit f64fb83

Browse files
committed
cache
1 parent 90679ba commit f64fb83

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,10 @@ def make_encoder_decoder_cache(
539539

540540
def make_mamba_cache(
541541
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
542+
cls_layers: Optional[Union[str, List[type]]] = None,
543+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
542544
) -> "MambaCache": # noqa: F821
543-
"Creates a ``MambaCache``."
545+
"""Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused."""
544546
# import is moved here because this part is slow.
545547
try:
546548
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -591,8 +593,13 @@ def get_text_config(self, *args, **kwargs):
591593

592594
def make_sliding_window_cache(
593595
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
596+
cls_layers: Optional[Union[str, List[type]]] = None,
597+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
594598
) -> transformers.cache_utils.SlidingWindowCache:
595-
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
599+
"""
600+
Creates a :class:`transformers.cache_utils.SlidingWindowCache`.
601+
`cls_layers`, `cls_kwargs` are unused.
602+
"""
596603
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
597604

598605
class _config:
@@ -654,7 +661,8 @@ def make_hybrid_cache(
654661
max_cache_len: Optional[int] = None,
655662
max_batch_size: Optional[int] = None,
656663
sliding_window: Optional[int] = None,
657-
cls_layers: Optional[List[type]] = None,
664+
cls_layers: Optional[Union[str, List[type]]] = None,
665+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
658666
) -> transformers.cache_utils.HybridCache:
659667
"""
660668
Creates an instance of :class:`transformers.cache_utils.HybridCache`.
@@ -663,7 +671,7 @@ def make_hybrid_cache(
663671
:param key_value_pairs: list of pairs of (key, values)
664672
:return: :class:`transformers.cache_utils.HybridCache`
665673
666-
`cls_layers` is unused.
674+
`cls_layers`, `cls_kwargs` are unused.
667675
668676
Example:
669677

0 commit comments

Comments
 (0)