@@ -539,8 +539,10 @@ def make_encoder_decoder_cache(
539539
540540def make_mamba_cache (
541541 key_value_pairs : List [Tuple [torch .Tensor , torch .Tensor ]],
542+ cls_layers : Optional [Union [str , List [type ]]] = None ,
543+ cls_kwargs : Optional [Union [Dict [str , int ], List [Dict [str , int ]]]] = None ,
542544) -> "MambaCache" : # noqa: F821
543- "Creates a ``MambaCache``."
545+ """Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused."""
544546 # import is moved here because this part is slow.
545547 try :
546548 from transformers .models .mamba .modeling_mamba import MambaCache
@@ -591,8 +593,13 @@ def get_text_config(self, *args, **kwargs):
591593
592594 def make_sliding_window_cache (
593595 key_value_pairs : Union [List [torch .Tensor ], List [Tuple [torch .Tensor , torch .Tensor ]]],
596+ cls_layers : Optional [Union [str , List [type ]]] = None ,
597+ cls_kwargs : Optional [Union [Dict [str , int ], List [Dict [str , int ]]]] = None ,
594598 ) -> transformers .cache_utils .SlidingWindowCache :
595- "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
599+ """
600+ Creates a :class:`transformers.cache_utils.SlidingWindowCache`.
601+ `cls_layers`, `cls_kwargs` are unused.
602+ """
596603 key_value_pairs = _preprocess_key_value_pairs (key_value_pairs )
597604
598605 class _config :
@@ -654,6 +661,8 @@ def make_hybrid_cache(
654661 max_cache_len : Optional [int ] = None ,
655662 max_batch_size : Optional [int ] = None ,
656663 sliding_window : Optional [int ] = None ,
664+ cls_layers : Optional [Union [str , List [type ]]] = None ,
665+ cls_kwargs : Optional [Union [Dict [str , int ], List [Dict [str , int ]]]] = None ,
657666 ) -> transformers .cache_utils .HybridCache :
658667 """
659668 Creates an instance of :class:`transformers.cache_utils.HybridCache`.
@@ -662,6 +671,8 @@ def make_hybrid_cache(
662671 :param key_value_pairs: list of pairs of (key, values)
663672 :return: :class:`transformers.cache_utils.HybridCache`
664673
674+ `cls_layers`, `cls_kwargs` are unused.
675+
665676 Example:
666677
667678 .. runpython::
@@ -742,16 +753,22 @@ def make_hybrid_cache(
742753 not max_batch_size and not max_cache_len
743754 ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
744755 max_batch_size = key_value_pairs [0 ][0 ].shape [0 ]
756+ assert max_cache_len is not None or all (
757+ isinstance (kv [0 ].shape [2 ], int ) for kv in key_value_pairs
758+ ), (
759+ f"Cannot determine max_cache_len with "
760+ f"shapes={ [kv [0 ].shape for kv in key_value_pairs ]} "
761+ )
745762 sets_of_dim = set (kv [0 ].shape [2 ] for kv in key_value_pairs )
746763 if len (sets_of_dim ) == 1 :
747- max_cache_len = sets_of_dim . pop ()
748- sliding_window = max_cache_len
764+ if max_cache_len is None :
765+ max_cache_len = sets_of_dim . pop ()
749766 else :
750767 assert (
751768 len (sets_of_dim ) == 2
752769 ), f"Not implemented for more than 2 dimensions { sets_of_dim } "
753- max_cache_len = max ( sets_of_dim )
754- sliding_window = min (sets_of_dim )
770+ if max_cache_len is None :
771+ max_cache_len = max (sets_of_dim )
755772 layer_types = [
756773 "full_attention" if i == max_cache_len else "sliding_attention"
757774 for i in [kv [0 ].shape [2 ] for kv in key_value_pairs ]
@@ -760,8 +777,8 @@ def make_hybrid_cache(
760777 assert (
761778 max_batch_size and max_cache_len
762779 ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
763- if sliding_window is None :
764- sliding_window = max_cache_len
780+ if sliding_window is None :
781+ sliding_window = max_cache_len
765782 _max_cache_len = max_cache_len
766783 _sliding_window = sliding_window
767784
0 commit comments