@@ -272,8 +272,8 @@ def __init__(self, store_dir: str, submodule_name: str = None):
272272 self ._range_to_shard_idx = np .cumsum ([0 ] + [s .shape [0 ] for s in self .shards ])
273273 if "store_tokens" in self .config and self .config ["store_tokens" ]:
274274 self ._tokens = th .load (
275- os .path .join (store_dir , "tokens.pt" ), weights_only = True
276- ). cpu ()
275+ os .path .join (store_dir , "tokens.pt" ), weights_only = True , map_location = th . device ( "cpu" )
276+ )
277277
278278 self ._sequence_ranges = None
279279 self ._mean = None
@@ -753,10 +753,27 @@ class PairedActivationCache:
753753 def __init__ (self , store_dir_1 : str , store_dir_2 : str , submodule_name : str = None ):
754754 self .activation_cache_1 = ActivationCache (store_dir_1 , submodule_name )
755755 self .activation_cache_2 = ActivationCache (store_dir_2 , submodule_name )
756- assert len (self .activation_cache_1 ) == len (self .activation_cache_2 )
757-
756+ if len (self .activation_cache_1 ) != len (self .activation_cache_2 ):
757+ min_len = min (len (self .activation_cache_1 ), len (self .activation_cache_2 ))
758+ assert self .activation_cache_1 .tokens is not None and self .activation_cache_2 .tokens is not None , "Caches have not the same length and tokens are not stored"
759+ assert torch .all (self .activation_cache_1 .tokens [:min_len ] == self .activation_cache_2 .tokens [:min_len ]), "Tokens do not match"
760+ self ._len = min_len
761+ print (f"Warning: Caches have not the same length and tokens are not stored. Using the first { min_len } tokens." )
762+ if len (self .activation_cache_1 ) > self ._len :
763+ self ._sequence_ranges = self .activation_cache_2 .sequence_ranges
764+ else :
765+ self ._sequence_ranges = self .activation_cache_1 .sequence_ranges
766+ else :
767+ assert len (self .activation_cache_1 ) == len (self .activation_cache_2 ), f"Lengths do not match: { len (self .activation_cache_1 )} != { len (self .activation_cache_2 )} "
768+ self ._len = len (self .activation_cache_1 )
769+
770+ if self .activation_cache_1 .tokens is not None and self .activation_cache_2 .tokens is not None :
771+ assert torch .all (self .activation_cache_1 .tokens [:self ._len ] == self .activation_cache_2 .tokens [:self ._len ]), "Tokens do not match"
772+
773+
774+
758775 def __len__ (self ):
759- return len ( self .activation_cache_1 )
776+ return self ._len
760777
761778 def __getitem__ (self , index ):
762779 if isinstance (index , slice ):
@@ -776,17 +793,11 @@ def __getitem__(self, index):
776793
777794 @property
778795 def tokens (self ):
779- return th .stack (
780- (self .activation_cache_1 .tokens , self .activation_cache_2 .tokens ), dim = 0
781- )
796+ return th .stack ((self .activation_cache_1 .tokens [:self ._len ], self .activation_cache_2 .tokens [:self ._len ]), dim = 0 )
782797
783798 @property
784799 def sequence_ranges (self ):
785- seq_starts_1 = self .activation_cache_1 .sequence_ranges
786- seq_starts_2 = self .activation_cache_2 .sequence_ranges
787- if seq_starts_1 is not None and seq_starts_2 is not None :
788- return th .stack ((seq_starts_1 , seq_starts_2 ), dim = 0 )
789- return None
800+ return self ._sequence_ranges
790801
791802 @property
792803 def mean (self ):
0 commit comments