@@ -44,7 +44,11 @@ def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytr
4444 ca = CacheKeyValue (cache )
4545 flat = list (itertools .chain .from_iterable (zip (ca .key_cache , ca .value_cache )))
4646 unique = set (ca .cls_layers ) if ca .cls_layers else None
47- if unique is None or (len (unique ) == 1 and unique .pop ().__name__ == "DynamicLayer" ):
47+ if (
48+ cache .__class__ .__name__ != "DynamicCache"
49+ or unique is None
50+ or (len (unique ) == 1 and unique .pop ().__name__ == "DynamicLayer" )
51+ ):
4852 keys = list (
4953 itertools .chain .from_iterable (
5054 (f"key_{ i } " , f"value_{ i } " ) for i in range (len (ca .key_cache ))
@@ -80,14 +84,13 @@ def _unflatten_cache(
8084 )
8185 if expected == context :
8286 res = make_cache (list (zip (values [::2 ], values [1 ::2 ])))
83- assert output_type is None or isinstance (
84- res , output_type
85- ), f"Type mismatch between { output_type } (expected) and { type (res )} "
86- return res
87-
88- cls_layer_names = [SHORTEN_LAYER_NAMES [name .split ("_" )[1 ][0 ]] for name in context ][::2 ]
89- cls_layers = [getattr (transformers .cache_utils , cls_name ) for cls_name in cls_layer_names ]
90- res = make_cache (list (zip (values [::2 ], values [1 ::2 ])), cls_layers = cls_layers )
87+ else :
88+ cls_layer_names = [SHORTEN_LAYER_NAMES [name .split ("_" )[1 ][0 ]] for name in context ][::2 ]
89+ cls_layers = [
90+ getattr (transformers .cache_utils , cls_name ) for cls_name in cls_layer_names
91+ ]
92+ res = make_cache (list (zip (values [::2 ], values [1 ::2 ])), cls_layers = cls_layers )
93+
9194 assert output_type is None or isinstance (
9295 res , output_type
9396 ), f"Type mismatch between { output_type } (expected) and { type (res )} "
0 commit comments